/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

namespace facebook::velox::wave {
/// Registers one embedded header. 'headerText' begins with the header's
/// include path on its own line, followed by the header's complete source
/// text (the velox_experimental_wave_* strings in this file have that shape).
/// Each caller stores the returned bool in a namespace-scope variable so that
/// registration happens as part of that variable's initialization.
/// Only declared here; the definition is not part of this file's visible text.
bool registerHeader(const char* headerText);
// Embedded copy of velox/experimental/wave/common/BitUtil.cuh. The first line
// of the string is the header's include path (the same path that the embedded
// Scan.cuh text #includes); everything after it is the header text. The
// string is handed to registerHeader() right after its definition, so every
// byte of the literal is part of the registered header.
// NOTE(review): this looks generated from BitUtil.cuh; presumably it should be
// regenerated rather than hand-edited — confirm with the generator.
// NOTE(review): in the embedded text, highMask<T>(0) shifts lowMask<T>(0) left
// by the full bit width of T, which is undefined behavior in C++; presumably
// no caller passes 0 — confirm.
const char* velox_experimental_wave_common_BitUtil_cuh =
    "velox/experimental/wave/common/BitUtil.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <stdint.h>\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "constexpr int32_t kWarpThreads = 32;\n"
    "\n"
    "template <typename T, typename U>\n"
    "__host__ __device__ constexpr inline T roundUp(T value, U factor) {\n"
    "  return (value + (factor - 1)) / factor * factor;\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "constexpr T __device__ __host__ lowMask(int bits) {\n"
    "  /****\n"
    "   * NVCC BUG: If the special case for all bits is not in, all modes except -G\n"
    "   * produce a 0 mask for 32 or 64 bits.\n"
    "   ****/\n"
    "  return bits == 8 * sizeof(T) ? ~static_cast<T>(0)\n"
    "                               : (static_cast<T>(1) << bits) - 1;\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "constexpr inline __device__ __host__ T highMask(int bits) {\n"
    "  return lowMask<T>(bits) << ((sizeof(T) * 8) - bits);\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "inline T* __device__ __host__ addBytes(T* ptr, int bytes) {\n"
    "  return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + bytes);\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "inline const T* __device__ __host__ addBytes(const T* ptr, int bytes) {\n"
    "  return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + bytes);\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "inline T* __device__ __host__ addCast(void* ptr, int bytes) {\n"
    "  return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + bytes);\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "inline const T* __device__ __host__ addCast(const void* ptr, int bytes) {\n"
    "  return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + bytes);\n"
    "}\n"
    "\n"
    "inline unsigned int __device__\n"
    "deviceScale32(unsigned int n, unsigned int scale) {\n"
    "  return (static_cast<unsigned long long>(static_cast<unsigned int>(n)) *\n"
    "          scale) >>\n"
    "      32;\n"
    "}\n"
    "\n"
    "__device__ __forceinline__ unsigned int LaneId() {\n"
    "  return threadIdx.x % kWarpThreads;\n"
    "}\n"
    "\n"
    "/* Log2 included from cub */\n"
    "/**\n"
    " * \\brief Statically determine log2(N), rounded up.\n"
    " *\n"
    " * For example:\n"
    " *     Log2<8>::VALUE   // 3\n"
    " *     Log2<3>::VALUE   // 2\n"
    " */\n"
    "template <int N, int CURRENT_VAL = N, int COUNT = 0>\n"
    "struct Log2 {\n"
    "  /// Static logarithm value\n"
    "  enum {\n"
    "    VALUE = Log2<N, (CURRENT_VAL >> 1), COUNT + 1>::VALUE\n"
    "  }; // Inductive case\n"
    "};\n"
    "\n"
    "template <int N, int COUNT>\n"
    "struct Log2<N, 0, COUNT> {\n"
    "  enum {\n"
    "    VALUE = (1 << (COUNT - 1) < N) ? // Base case\n"
    "        COUNT\n"
    "                                   : COUNT - 1\n"
    "  };\n"
    "};\n"
    "\n"
    "namespace detail {\n"
    "inline __device__ bool isLastInWarp() {\n"
    "  return (threadIdx.x & (kWarpThreads - 1)) == (kWarpThreads - 1);\n"
    "}\n"
    "} // namespace detail\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Registers the header text above while this namespace-scope variable is
// initialized; the variable only holds registerHeader()'s return value.
bool velox_experimental_wave_common_BitUtil_cuh_reg =
    registerHeader(velox_experimental_wave_common_BitUtil_cuh);
// Embedded copy of velox/experimental/wave/common/Scan.cuh. The first line of
// the string is the header's include path and the rest is the header text.
// The whole string is passed to registerHeader() right after its definition,
// so every byte of the literal is part of the registered header.
//
// Notes on the embedded header as written here:
//  - WarpScan::exclusiveSum(input, out, initial, aggregate) delegates to
//    inclusiveSum. An earlier version named a non-existent 'inclusivesum',
//    which failed to compile once that overload was instantiated.
//  - The block-wide exclusiveSum<T, kBlockSize> takes the single-warp path for
//    any kBlockSize <= kWarpThreads, matching inclusiveSum. It also
//    instantiates its inner warp scan with at least 2 lanes. Without that,
//    kBlockSize < kWarpThreads would instantiate WarpScan<T, 0>, whose
//    Log2<0> evaluates 1 << -1 and does not compile.
// NOTE(review): the Scan.cuh this string was presumably generated from needs
// the same changes, or regenerating this file will revert them.
const char* velox_experimental_wave_common_Scan_cuh =
    "velox/experimental/wave/common/Scan.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include \"velox/experimental/wave/common/BitUtil.cuh\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "template <typename T, int32_t kNumLanes = kWarpThreads>\n"
    "struct WarpScan {\n"
    "  enum {\n"
    "    /// Whether the logical warp size and the PTX warp size coincide\n"
    "    IS_ARCH_WARP = (kNumLanes == kWarpThreads),\n"
    "\n"
    "    /// The number of warp scan steps\n"
    "    STEPS = Log2<kNumLanes>::VALUE,\n"
    "  };\n"
    "\n"
    "  int laneId;\n"
    "\n"
    "  static constexpr unsigned int member_mask =\n"
    "      kNumLanes == 32 ? 0xffffffff : (1 << kNumLanes) - 1;\n"
    "\n"
    "  __device__ WarpScan() : laneId(LaneId()) {}\n"
    "\n"
    "  __device__ __forceinline__ void exclusiveSum(\n"
    "      T input, ///< [in] Calling thread's input item.\n"
    "      T& exclusive_output) ///< [out] Calling thread's output item.  May be\n"
    "                           ///< aliased with \\p input.\n"
    "  {\n"
    "    T initial_value = 0;\n"
    "    exclusiveSum(input, exclusive_output, initial_value);\n"
    "  }\n"
    "\n"
    "  __device__ __forceinline__ void exclusiveSum(\n"
    "      T input, ///< [in] Calling thread's input item.\n"
    "      T& exclusive_output, ///< [out] Calling thread's output item.  May be\n"
    "                           ///< aliased with \\p input.\n"
    "      T initial_value) {\n"
    "    T inclusive_output;\n"
    "    inclusiveSum(input, inclusive_output);\n"
    "\n"
    "    exclusive_output = initial_value + inclusive_output - input;\n"
    "  }\n"
    "\n"
    "  __device__ __forceinline__ void exclusiveSum(\n"
    "      T input,\n"
    "      T& exclusive_output,\n"
    "      T initial_value,\n"
    "      T& warp_aggregate) {\n"
    "    T inclusive_output;\n"
    "    inclusiveSum(input, inclusive_output);\n"
    "    warp_aggregate = __shfl_sync(member_mask, inclusive_output, kNumLanes - 1);\n"
    "    exclusive_output = initial_value + inclusive_output - input;\n"
    "  }\n"
    "\n"
    "  __device__ __forceinline__ void inclusiveSum(T input, T& inclusive_output) {\n"
    "    inclusive_output = input;\n"
    "    if (IS_ARCH_WARP || (member_mask & (1 << laneId)) != 0) {\n"
    "#pragma unroll\n"
    "      for (int STEP = 0; STEP < STEPS; STEP++) {\n"
    "        int offset = (1 << STEP);\n"
    "        T other = __shfl_up_sync(member_mask, inclusive_output, offset);\n"
    "        if (laneId >= offset) {\n"
    "          inclusive_output += other;\n"
    "        }\n"
    "      }\n"
    "    }\n"
    "  }\n"
    "};\n"
    "\n"
    "template <typename T, int32_t kNumLanes = kWarpThreads>\n"
    "struct WarpReduce {\n"
    "  static constexpr int32_t STEPS = Log2<kNumLanes>::VALUE;\n"
    "\n"
    "  int laneId;\n"
    "\n"
    "  /// 32-thread physical warp member mask of logical warp\n"
    "\n"
    "  static constexpr unsigned int member_mask =\n"
    "      kNumLanes == 32 ? 0xffffffff : (1 << kNumLanes) - 1;\n"
    "\n"
    "  __device__ WarpReduce() : laneId(LaneId()) {}\n"
    "\n"
    "  template <typename Func>\n"
    "  __device__ __forceinline__ T reduce(T val, Func func) {\n"
    "    for (int32_t offset = kNumLanes / 2; offset > 0; offset = offset >> 1) {\n"
    "      T other = __shfl_down_sync(0xffffffff, val, offset);\n"
    "      if (laneId + offset < kNumLanes) {\n"
    "        val = func(val, other);\n"
    "      }\n"
    "    }\n"
    "    return val;\n"
    "  }\n"
    "};\n"
    "\n"
    "/// Returns the block wide exclusive sum (sum of 'input' for all\n"
    "/// lanes below threadIdx.x). If 'total' is non-nullptr, the block\n"
    "/// wide sum is returned in '*total'. 'temp' must have\n"
    "/// exclusiveSumTempSize() writable bytes aligned for T.\n"
    "template <typename T, int32_t kBlockSize>\n"
    "inline __device__ T exclusiveSum(T input, T* total, T* temp) {\n"
    "  constexpr int32_t kNumWarps = kBlockSize / kWarpThreads;\n"
    "  using Scan = WarpScan<T>;\n"
    "  T sum;\n"
    "  Scan().exclusiveSum(input, sum);\n"
    "  if (kBlockSize <= kWarpThreads) {\n"
    "    if (total) {\n"
    "      if (threadIdx.x == kBlockSize - 1) {\n"
    "        *total = input + sum;\n"
    "      }\n"
    "      __syncthreads();\n"
    "    }\n"
    "    return sum;\n"
    "  }\n"
    "  if (detail::isLastInWarp()) {\n"
    "    temp[threadIdx.x / kWarpThreads] = input + sum;\n"
    "  }\n"
    "  __syncthreads();\n"
    "  constexpr int32_t kInnerWidth = kNumWarps < 2 ? 2 : kNumWarps;\n"
    "  using InnerScan = WarpScan<T, kInnerWidth>;\n"
    "  T warpSum = threadIdx.x < kNumWarps ? temp[threadIdx.x] : 0;\n"
    "  T blockSum;\n"
    "  InnerScan().exclusiveSum(warpSum, blockSum);\n"
    "  if (threadIdx.x < kNumWarps) {\n"
    "    temp[threadIdx.x] = blockSum;\n"
    "    if (total && threadIdx.x == kNumWarps - 1) {\n"
    "      *total = warpSum + blockSum;\n"
    "    }\n"
    "  }\n"
    "  __syncthreads();\n"
    "  return sum + temp[threadIdx.x / kWarpThreads];\n"
    "}\n"
    "\n"
    "/// Returns the block wide inclusive sum (sum of 'input' for all\n"
    "/// lanes at or below threadIdx.x). 'temp' must have\n"
    "/// exclusiveSumTempSize() writable bytes aligned for T. '*total' is set to the\n"
    "/// TB-wide total if 'total' is not nullptr.\n"
    "template <typename T, int32_t kBlockSize>\n"
    "inline __device__ T inclusiveSum(T input, T* total, T* temp) {\n"
    "  constexpr int32_t kNumWarps = kBlockSize / kWarpThreads;\n"
    "  using Scan = WarpScan<T>;\n"
    "  T sum;\n"
    "  Scan().inclusiveSum(input, sum);\n"
    "  if (kBlockSize <= kWarpThreads) {\n"
    "    if (total != nullptr) {\n"
    "      if (threadIdx.x == kBlockSize - 1) {\n"
    "        *total = sum;\n"
    "      }\n"
    "      __syncthreads();\n"
    "    }\n"
    "    return sum;\n"
    "  }\n"
    "  if (detail::isLastInWarp()) {\n"
    "    temp[threadIdx.x / kWarpThreads] = sum;\n"
    "  }\n"
    "  __syncthreads();\n"
    "  constexpr int32_t kInnerWidth = kNumWarps < 2 ? 2 : kNumWarps;\n"
    "  using InnerScan = WarpScan<T, kInnerWidth>;\n"
    "  T warpSum = threadIdx.x < kInnerWidth ? temp[threadIdx.x] : 0;\n"
    "  T blockSum;\n"
    "  InnerScan().exclusiveSum(warpSum, blockSum);\n"
    "  if (threadIdx.x < kInnerWidth) {\n"
    "    temp[threadIdx.x] = blockSum;\n"
    "  }\n"
    "  if (total != nullptr && threadIdx.x == kInnerWidth - 1) {\n"
    "    *total = blockSum + warpSum;\n"
    "  }\n"
    "  __syncthreads();\n"
    "  return sum + temp[threadIdx.x / kWarpThreads];\n"
    "}\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Registers the header text above while this namespace-scope variable is
// initialized; the variable only holds registerHeader()'s return value.
bool velox_experimental_wave_common_Scan_cuh_reg =
    registerHeader(velox_experimental_wave_common_Scan_cuh);
// Embedded copy of velox/experimental/wave/vector/Operand.h. The first line of
// the string is the header's include path and the rest is the header text.
// The string is passed to registerHeader() right after its definition, so
// every byte of the literal is part of the registered header.
// The embedded text guards the statusNumRows() declaration with
// #ifndef __CUDACC_RTC__, so that declaration is hidden whenever the text is
// compiled with __CUDACC_RTC__ defined (NVRTC defines it).
// NOTE(review): a few comments inside the embedded text disagree with its
// code. 'Operand::nulls' says "a 1 means not-null" while kNotNull is 255.
// KernelError's comment describes a tagged message pointer, although the
// struct holds int32_t messageEnum and int64_t extra. Presumably these should
// be corrected in the source header and then regenerated — confirm.
const char* velox_experimental_wave_vector_Operand_h =
    "velox/experimental/wave/vector/Operand.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <stdint.h>\n"
    "\n"
    "/// Types for device side access to device vectors. Separate header\n"
    "/// independent of Velox headers, included for both host and device\n"
    "/// side files.\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "/// Copy of TypeKind in velox/type/Type.h. Type.h is incompatible with Cuda\n"
    "/// headers, therefore duplicated here.\n"
    "enum class WaveTypeKind : int8_t {\n"
    "\n"
    "  BOOLEAN = 0,\n"
    "  TINYINT = 1,\n"
    "  SMALLINT = 2,\n"
    "  INTEGER = 3,\n"
    "  BIGINT = 4,\n"
    "  REAL = 5,\n"
    "  DOUBLE = 6,\n"
    "  VARCHAR = 7,\n"
    "  VARBINARY = 8,\n"
    "  TIMESTAMP = 9,\n"
    "  HUGEINT = 10,\n"
    "  // Enum values for ComplexTypes start after 30 to leave\n"
    "  // some values space to accommodate adding new scalar/native\n"
    "  // types above.\n"
    "  ARRAY = 30,\n"
    "  MAP = 31,\n"
    "  ROW = 32,\n"
    "  UNKNOWN = 33,\n"
    "  FUNCTION = 34,\n"
    "  OPAQUE = 35,\n"
    "  INVALID = 36\n"
    "};\n"
    "\n"
    "template <typename T>\n"
    "struct WaveTypeTrait {};\n"
    "\n"
    "template <>\n"
    "struct WaveTypeTrait<int32_t> {\n"
    "  static constexpr WaveTypeKind typeKind = WaveTypeKind::INTEGER;\n"
    "};\n"
    "\n"
    "template <>\n"
    "struct WaveTypeTrait<uint32_t> {\n"
    "  static constexpr WaveTypeKind typeKind = WaveTypeKind::INTEGER;\n"
    "};\n"
    "\n"
    "template <>\n"
    "struct WaveTypeTrait<int64_t> {\n"
    "  static constexpr WaveTypeKind typeKind = WaveTypeKind::BIGINT;\n"
    "};\n"
    "template <>\n"
    "struct WaveTypeTrait<uint64_t> {\n"
    "  static constexpr WaveTypeKind typeKind = WaveTypeKind::BIGINT;\n"
    "};\n"
    "\n"
    "// Normal thread block size for Wave kernels\n"
    "constexpr int32_t kBlockSize = 256;\n"
    "using OperandId = int32_t;\n"
    "\n"
    "constexpr OperandId kNoOperand = ~0;\n"
    "\n"
    "using OperandIndex = uint16_t;\n"
    "constexpr OperandIndex kEmpty = ~0;\n"
    "\n"
    "// operand indices above this are offsets into TB shared memory arrays.\n"
    "constexpr OperandIndex kMinSharedMemIndex = 0x8000;\n"
    "\n"
    "// Number of nullable locals in shared memory. Each has kBlockSize null bytes at\n"
    "// the start of the TB shared memory. 0 means no nulls. 1 means first kBlockSize\n"
    "// bytes are nulls, 2 means second kBlockSize  bytes are null flags etc.\n"
    "constexpr uint16_t kSharedNullMask = 3;\n"
    "\n"
    "/// Start of the parameter array in the TB shared memory. 13 bits. Shift 1 left\n"
    "/// to get offset.\n"
    "constexpr uint16_t kSharedOperandMask = 0x7ffc;\n"
    "\n"
    "/// Describes an operand for a Wave kernel instruction. The same\n"
    "/// insttruction is interpreted by multiple thread blocks in the\n"
    "/// kernel invocation. When accessing an operand, we have the base\n"
    "/// index of the thread block. This is blockIdx.x * blockDim.x if all\n"
    "/// thread blocks run the same instructions. When the blocks run\n"
    "/// different instruction streams, the base is (blockIdx.x - <index of\n"
    "/// first block with this instruction stream>) * blockDim.x. We also have a\n"
    "/// shared memory pointer to thread block shared memory. Some operands may come\n"
    "/// from thread block shared memory.\n"
    "\n"
    "constexpr uint8_t kNull = 0;\n"
    "constexpr uint8_t kNotNull = 255;\n"
    "\n"
    "/// Indicates a null value introduced by wrap in 'indices'.\n"
    "constexpr int32_t kNullIndex = -1;\n"
    "\n"
    "struct Operand {\n"
    "  static constexpr int32_t kPointersInOperand = 4;\n"
    "\n"
    "  int32_t indexMask;\n"
    "\n"
    "  int32_t size;\n"
    "\n"
    "  // Array of flat base values. Cast to pod type or StringView.\n"
    "  void* base;\n"
    "\n"
    "  // Array of null indicators. No nulls if nullptr.  A 1 means not-null, for\n"
    "  // consistency with Velox.\n"
    "  uint8_t* nulls;\n"
    "\n"
    "  // If non-nullptr, provides index into 'base. Subscripted with the\n"
    "  // blockIdx - idx of first bllock wit this instruction\n"
    "  // stream. Different thread blocks may or may not have indices for\n"
    "  // a given operand.\n"
    "  int32_t** indices;\n"
    "};\n"
    "\n"
    "/// Per-lane error code.\n"
    "enum class ErrorCode : uint8_t {\n"
    "  // All operations completed.\n"
    "  kOk = 0,\n"
    "\n"
    "  // Set on entry when continuing, e.g. produce more data from hash probe.\n"
    "  kContinue,\n"
    "\n"
    "  // all codes from here onwards mean the lane is off\n"
    "  // Catchall for runtime errors.\n"
    "  kError,\n"
    "\n"
    "  kInsufficientMemory,\n"
    "\n"
    "  kInactive,\n"
    "\n"
    "};\n"
    "\n"
    "/// Thread block status with count of active lanes and a per lane\n"
    "/// error code and continue points for operators that can produce more\n"
    "/// data.\n"
    "struct BlockStatus {\n"
    "  int32_t numRows{0};\n"
    "  ErrorCode errors[kBlockSize];\n"
    "};\n"
    "\n"
    "/// User error status returned from kernels. Represents one error from\n"
    "/// an arbitrary kernel thread with an error.\n"
    "struct KernelError {\n"
    "  static constexpr uintptr_t kNoParam = 0;\n"
    "  static constexpr uintptr_t kStringParam = 1UL << 60;\n"
    "  static constexpr uintptr_t kInt64Param = 2UL << 60;\n"
    "\n"
    "  /// Host addressable constant string with error message. nullptr if\n"
    "  /// no error message. The 4 high bits of the pointer indicate if the\n"
    "  /// message is compleemented by 'number' (kInt64Param) or\n"
    "  /// 'ptr'kStringParam) (k. If kNoParam the string is the only error\n"
    "  /// info.\n"
    "  int32_t messageEnum{0};\n"
    "  int64_t extra;\n"
    "};\n"
    "\n"
    "/// Describes the location of an instruction's return state in the\n"
    "/// BlockStatus area. The return states are allocated right above\n"
    "/// the BlockStatus array. First are grid level statuses for instructions that\n"
    "/// return a status. After this are block level statuses.\n"
    "\n"
    "struct InstructionStatus {\n"
    "  // Offset of containing instruction's grid state from the end of BlockStatus\n"
    "  // array.\n"
    "  uint16_t gridState{0};\n"
    "  // Total size of gridStates. Block level states start after the last grid\n"
    "  // state.\n"
    "  uint16_t gridStateSize{0};\n"
    "  // Start of per-block status. gridStateSize + numBlocks *\n"
    "  // blockState' is the  offset of the first per block status from the end of\n"
    "  // BlockStatus array.\n"
    "  uint16_t blockState{0};\n"
    "};\n"
    "\n"
    "/// Returns the number of active rows in 'status' for 'numBlocks'.\n"
    "#ifndef __CUDACC_RTC__\n"
    "int32_t statusNumRows(const BlockStatus* status, int32_t numBlocks);\n"
    "#endif\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Registers the header text above while this namespace-scope variable is
// initialized; the variable only holds registerHeader()'s return value.
bool velox_experimental_wave_vector_Operand_h_reg =
    registerHeader(velox_experimental_wave_vector_Operand_h);
// Embedded copy of velox/experimental/wave/exec/ErrorCode.h. The first line of
// the string is the header's include path and the rest is the header text.
// The embedded header only includes Operand.h and opens an empty namespace.
// Every byte of the literal is part of the registered header.
const char* velox_experimental_wave_exec_ErrorCode_h =
    "velox/experimental/wave/exec/ErrorCode.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include \"velox/experimental/wave/vector/Operand.h\"\n"
    "\n"
    "namespace facebook::velox::wave {} // namespace facebook::velox::wave\n";
// Registers the header text above while this namespace-scope variable is
// initialized; the variable only holds registerHeader()'s return value.
bool velox_experimental_wave_exec_ErrorCode_h_reg =
    registerHeader(velox_experimental_wave_exec_ErrorCode_h);
const char* velox_experimental_wave_exec_WaveCore_cuh =
    "velox/experimental/wave/exec/WaveCore.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "#pragma once\n"
    "\n"
    "#include \"velox/experimental/wave/common/Scan.cuh\"\n"
    "#include \"velox/experimental/wave/exec/ExprKernel.h\"\n"
    "#include \"velox/experimental/wave/vector/Operand.h\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "template <typename T>\n"
    "inline T* __device__ gridStatus(const WaveShared* shared, int32_t gridState) {\n"
    "  return reinterpret_cast<T*>(\n"
    "      roundUp(\n"
    "          reinterpret_cast<uintptr_t>(\n"
    "              &shared->status\n"
    "                   [shared->numBlocks - (shared->blockBase / kBlockSize)]),\n"
    "          8) +\n"
    "      gridState);\n"
    "}\n"
    "\n"
    "inline __device__ void setError(\n"
    "    WaveShared* shared,\n"
    "    ErrorCode& laneStatus,\n"
    "    bool insideTry,\n"
    "    int32_t messageEnum,\n"
    "    int64_t extra = 0) {\n"
    "  laneStatus = ErrorCode::kError;\n"
    "  if (insideTry) {\n"
    "    return;\n"
    "  }\n"
    "  auto* error = gridStatus<KernelError>(shared, 0);\n"
    "  error->messageEnum = messageEnum;\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "inline T* __device__\n"
    "gridStatus(const WaveShared* shared, const InstructionStatus& status) {\n"
    "  return gridStatus<T>(shared, status.gridState);\n"
    "}\n"
    "\n"
    "/// Returns a pointer to the first byte of block level status in the extra\n"
    "/// status area above the BlockStatus array for the current block.\n"
    "template <typename T>\n"
    "inline T* __device__ blockStatus(\n"
    "    const WaveShared* shared,\n"
    "    int32_t gridStatusSize,\n"
    "    int32_t blockStatusOffset,\n"
    "    int32_t blockStatusSize = sizeof(T)) {\n"
    "  return reinterpret_cast<T*>(\n"
    "      roundUp(\n"
    "          reinterpret_cast<uintptr_t>(shared->status) +\n"
    "              shared->numBlocks * sizeof(BlockStatus),\n"
    "          8) +\n"
    "      gridStatusSize + (blockStatusOffset * shared->numBlocks) +\n"
    "      (blockStatusSize * (shared->blockBase / kBlockSize)));\n"
    "}\n"
    "\n"
    "inline bool __device__ laneActive(ErrorCode code) {\n"
    "  return static_cast<uint8_t>(code) <=\n"
    "      static_cast<uint8_t>(ErrorCode::kContinue);\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "__device__ inline T& flatValue(void* base, int32_t blockBase) {\n"
    "  return reinterpret_cast<T*>(base)[blockBase + threadIdx.x];\n"
    "}\n"
    "\n"
    "/// Returns true if operand is non null. Sets 'value' to the value of the\n"
    "/// operand.\n"
    "template <typename T>\n"
    "__device__ __forceinline__ bool operandOrNull(\n"
    "    Operand** operands,\n"
    "    OperandIndex opIdx,\n"
    "    int32_t blockBase,\n"
    "    T& value) {\n"
    "  auto op = operands[opIdx];\n"
    "  auto index = threadIdx.x;\n"
    "  if (auto indicesInOp = op->indices) {\n"
    "    auto indices = indicesInOp[blockBase / kBlockSize];\n"
    "    if (indices) {\n"
    "      index = indices[index];\n"
    "      if (index == kNullIndex) {\n"
    "        return false;\n"
    "      }\n"
    "    } else {\n"
    "      index += blockBase;\n"
    "    }\n"
    "  } else {\n"
    "    index = (index + blockBase) & op->indexMask;\n"
    "  }\n"
    "  if (op->nulls && op->nulls[index] == kNull) {\n"
    "    return false;\n"
    "  }\n"
    "  value = reinterpret_cast<const T*>(op->base)[index];\n"
    "  return true;\n"
    "}\n"
    "\n"
    "template <bool kMayWrap, typename T>\n"
    "bool __device__ __forceinline__ valueOrNull(\n"
    "    Operand** operands,\n"
    "    OperandIndex opIdx,\n"
    "    int32_t blockBase,\n"
    "    T& value) {\n"
    "  auto op = operands[opIdx];\n"
    "  auto index = threadIdx.x;\n"
    "  if (!kMayWrap) {\n"
    "    index = (index + blockBase) & op->indexMask;\n"
    "    if (op->nulls && op->nulls[index] == kNull) {\n"
    "      return false;\n"
    "    }\n"
    "    value = reinterpret_cast<const T*>(op->base)[index];\n"
    "    return true;\n"
    "  }\n"
    "  if (auto indicesInOp = op->indices) {\n"
    "    auto indices = indicesInOp[blockBase / kBlockSize];\n"
    "    if (indices) {\n"
    "      index = indices[index];\n"
    "      if (index == kNullIndex) {\n"
    "        return false;\n"
    "      }\n"
    "    } else {\n"
    "      index += blockBase;\n"
    "    }\n"
    "  } else {\n"
    "    index = (index + blockBase) & op->indexMask;\n"
    "  }\n"
    "  if (op->nulls && op->nulls[index] == kNull) {\n"
    "    return false;\n"
    "  }\n"
    "  value = reinterpret_cast<const T*>(op->base)[index];\n"
    "  return true;\n"
    "}\n"
    "\n"
    "template <bool kMayWrap, typename T>\n"
    "void __device__ __forceinline__ loadValueOrNull(\n"
    "    Operand** operands,\n"
    "    OperandIndex opIdx,\n"
    "    int32_t blockBase,\n"
    "    T& value,\n"
    "    uint32_t& nulls) {\n"
    "  nulls = (nulls & ~(1U << (opIdx & 31))) |\n"
    "      (static_cast<uint32_t>(\n"
    "           valueOrNull<kMayWrap>(operands, opIdx, blockBase, value))\n"
    "       << (opIdx & 31));\n"
    "}\n"
    "\n"
    "template <bool kMayWrap, typename T>\n"
    "T __device__ __forceinline__\n"
    "nonNullOperand(Operand** operands, OperandIndex opIdx, int32_t blockBase) {\n"
    "  auto op = operands[opIdx];\n"
    "  auto index = threadIdx.x;\n"
    "  if (!kMayWrap) {\n"
    "    index = (index + blockBase) & op->indexMask;\n"
    "    return reinterpret_cast<const T*>(op->base)[index];\n"
    "  }\n"
    "  if (auto indicesInOp = op->indices) {\n"
    "    auto indices = indicesInOp[blockBase / kBlockSize];\n"
    "    if (indices) {\n"
    "      index = indices[index];\n"
    "    } else {\n"
    "      index += blockBase;\n"
    "    }\n"
    "  } else {\n"
    "    index = (index + blockBase) & op->indexMask;\n"
    "  }\n"
    "  return reinterpret_cast<const T*>(op->base)[index];\n"
    "}\n"
    "\n"
    "bool __device__ __forceinline__\n"
    "setRegisterNull(uint32_t& flags, int8_t bit, bool isNull) {\n"
    "  if (isNull) {\n"
    "    flags &= ~(1 << bit);\n"
    "  }\n"
    "  return isNull;\n"
    "}\n"
    "\n"
    "bool __device__ __forceinline__ isRegisterNull(uint32_t flags, int8_t bit) {\n"
    "  return 0 == (flags & (1 << bit));\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "__device__ inline T&\n"
    "flatOperand(Operand** operands, OperandIndex opIdx, int32_t blockBase) {\n"
    "  auto* op = operands[opIdx];\n"
    "  if (op->nulls) {\n"
    "    op->nulls[blockBase + threadIdx.x] = kNotNull;\n"
    "  }\n"
    "  return reinterpret_cast<T*>(op->base)[blockBase + threadIdx.x];\n"
    "}\n"
    "\n"
    "/// Clears 'bit' from 'flags' if notNull is false. Returns true if bit cleared.\n"
    "bool __device__ __forceinline__\n"
    "setNullRegister(uint32_t& flags, int8_t bit, bool notNull) {\n"
    "  if (!notNull) {\n"
    "    flags &= ~(1 << bit);\n"
    "  }\n"
    "  return !notNull;\n"
    "}\n"
    "\n"
    "/// Sets the lane's result to null for opIdx.\n"
    "__device__ inline void\n"
    "resultNull(Operand** operands, OperandIndex opIdx, int32_t blockBase) {\n"
    "  auto* op = operands[opIdx];\n"
    "  op->nulls[blockBase + threadIdx.x] = kNull;\n"
    "}\n"
    "\n"
    "__device__ inline void setNull(\n"
    "    Operand** operands,\n"
    "    OperandIndex opIdx,\n"
    "    int32_t blockBase,\n"
    "    bool isNull) {\n"
    "  auto* op = operands[opIdx];\n"
    "  op->nulls[blockBase + threadIdx.x] = isNull ? kNull : kNotNull;\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "__device__ inline T&\n"
    "flatResult(Operand** operands, OperandIndex opIdx, int32_t blockBase) {\n"
    "  auto* op = operands[opIdx];\n"
    "  if (op->nulls) {\n"
    "    op->nulls[blockBase + threadIdx.x] = kNotNull;\n"
    "  }\n"
    "  return reinterpret_cast<T*>(op->base)[blockBase + threadIdx.x];\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "__device__ inline T& flatResult(Operand* op, int32_t blockBase) {\n"
    "  return flatResult<T>(&op, 0, blockBase);\n"
    "}\n"
    "\n"
    "#define GENERATED_PREAMBLE(blockOffset)                                        \\\n"
    "  extern __shared__ char sharedChar[];                                         \\\n"
    "  WaveShared* shared = reinterpret_cast<WaveShared*>(sharedChar);              \\\n"
    "  if (threadIdx.x == 0) {                                                      \\\n"
    "    shared->operands = params.operands[0];                                     \\\n"
    "    shared->numBlocks = params.numBlocks;                                      \\\n"
    "    shared->numRowsPerThread = params.numRowsPerThread;                        \\\n"
    "    auto startBlock = blockIdx.x * shared->numRowsPerThread;                   \\\n"
    "    auto branchIdx = startBlock / shared->numBlocks;                           \\\n"
    "    startBlock = startBlock - (shared->numBlocks * branchIdx);                 \\\n"
    "    shared->programIdx = branchIdx;                                            \\\n"
    "    startBlock =                                                               \\\n"
    "        (startBlock / shared->numRowsPerThread) * shared->numRowsPerThread;    \\\n"
    "    int32_t numBlocksAbove = shared->numBlocks - startBlock;                   \\\n"
    "    if (numBlocksAbove < shared->numRowsPerThread) {                           \\\n"
    "      shared->numRowsPerThread = numBlocksAbove;                               \\\n"
    "    }                                                                          \\\n"
    "    shared->status = &params.status[startBlock];                               \\\n"
    "    shared->numRows = shared->status->numRows;                                 \\\n"
    "    shared->blockBase = startBlock * kBlockSize;                               \\\n"
    "    shared->states = params.operatorStates[0];                                 \\\n"
    "    shared->nthBlock = 0;                                                      \\\n"
    "    shared->streamIdx = params.streamIdx;                                      \\\n"
    "    shared->localContinue = false;                                             \\\n"
    "    shared->isContinue = params.startPC != nullptr;                            \\\n"
    "    if (shared->isContinue) {                                                  \\\n"
    "      shared->startLabel = params.startPC[shared->programIdx];                 \\\n"
    "    } else {                                                                   \\\n"
    "      shared->startLabel = -1;                                                 \\\n"
    "    }                                                                          \\\n"
    "    shared->extraWraps = params.extraWraps;                                    \\\n"
    "    shared->numExtraWraps = params.numExtraWraps;                              \\\n"
    "    shared->hasContinue = false;                                               \\\n"
    "    shared->stop = false;                                                      \\\n"
    "  }                                                                            \\\n"
    "  __syncthreads();                                                             \\\n"
    "  int32_t blockBase;                                                           \\\n"
    "  auto operands = shared->operands;                                            \\\n"
    "  ErrorCode laneStatus;                                                        \\\n"
    "  nextBlock:                                                                   \\\n"
    "  blockBase = shared->blockBase;                                               \\\n"
    "  if (!shared->isContinue) {                                                   \\\n"
    "    laneStatus =                                                               \\\n"
    "        threadIdx.x < shared->numRows ? ErrorCode::kOk : ErrorCode::kInactive; \\\n"
    "  } else {                                                                     \\\n"
    "    laneStatus = shared->status->errors[threadIdx.x];                          \\\n"
    "  }\n"
    "\n"
    "#define PROGRAM_EPILOGUE()                                \\\n"
    "  shared->status->errors[threadIdx.x] = laneStatus;       \\\n"
    "  __syncthreads();                                        \\\n"
    "  if (threadIdx.x == 0) {                                 \\\n"
    "    shared->status->numRows = shared->numRows;            \\\n"
    "    if (++shared->nthBlock >= shared->numRowsPerThread) { \\\n"
    "      shared->stop = true;                                \\\n"
    "    } else {                                              \\\n"
    "      ++shared->status;                                   \\\n"
    "      shared->numRows = shared->status->numRows;          \\\n"
    "      shared->blockBase += kBlockSize;                    \\\n"
    "    }                                                     \\\n"
    "  }                                                       \\\n"
    "  __syncthreads();                                        \\\n"
    "  if (!shared->stop) {                                    \\\n"
    "    goto nextBlock;                                       \\\n"
    "  }\n"
    "\n"
    "__device__ __forceinline__ void filterKernel(\n"
    "    bool flag,\n"
    "    Operand** operands,\n"
    "    OperandIndex indicesIdx,\n"
    "    int32_t blockBase,\n"
    "    WaveShared* shared,\n"
    "    ErrorCode& laneStatus) {\n"
    "  bool isPassed = flag && laneActive(laneStatus);\n"
    "  uint32_t bits = __ballot_sync(0xffffffff, isPassed);\n"
    "  if ((threadIdx.x & (kWarpThreads - 1)) == 0) {\n"
    "    reinterpret_cast<int32_t*>(&shared->data)[threadIdx.x / kWarpThreads] =\n"
    "        __popc(bits);\n"
    "  }\n"
    "  __syncthreads();\n"
    "  if (threadIdx.x < kWarpThreads) {\n"
    "    constexpr int32_t kNumWarps = kBlockSize / kWarpThreads;\n"
    "    uint32_t cnt = threadIdx.x < kNumWarps\n"
    "        ? reinterpret_cast<int32_t*>(&shared->data)[threadIdx.x]\n"
    "        : 0;\n"
    "    uint32_t sum;\n"
    "    using Scan = WarpScan<uint32_t, kBlockSize / kWarpThreads>;\n"
    "    Scan().exclusiveSum(cnt, sum);\n"
    "    if (threadIdx.x < kNumWarps) {\n"
    "      if (threadIdx.x == kNumWarps - 1) {\n"
    "        shared->numRows = cnt + sum;\n"
    "      }\n"
    "      reinterpret_cast<int32_t*>(&shared->data)[threadIdx.x] = sum;\n"
    "    }\n"
    "  }\n"
    "  __syncthreads();\n"
    "  if (bits & (1 << (threadIdx.x & (kWarpThreads - 1)))) {\n"
    "    auto* indices = reinterpret_cast<int32_t*>(operands[indicesIdx]->base);\n"
    "    auto start = blockBase +\n"
    "        reinterpret_cast<int32_t*>(&shared->data)[threadIdx.x / kWarpThreads];\n"
    "    auto bit = start +\n"
    "        __popc(bits & lowMask<uint32_t>(threadIdx.x & (kWarpThreads - 1)));\n"
    "    indices[bit] = blockBase + threadIdx.x;\n"
    "  }\n"
    "  laneStatus =\n"
    "      threadIdx.x < shared->numRows ? ErrorCode::kOk : ErrorCode::kInactive;\n"
    "  __syncthreads();\n"
    "}\n"
    "\n"
    "__device__ void __forceinline__ wrapKernel(\n"
    "    const OperandIndex* wraps,\n"
    "    int32_t numWraps,\n"
    "    OperandIndex indicesIdx,\n"
    "    Operand** operands,\n"
    "    int32_t blockBase,\n"
    "    WaveShared* shared) {\n"
    "  Operand* op = operands[indicesIdx];\n"
    "  auto* filterIndices = reinterpret_cast<int32_t*>(op->base);\n"
    "  if (filterIndices[blockBase + shared->numRows - 1] ==\n"
    "      shared->numRows + blockBase - 1) {\n"
    "    // There is no cardinality change.\n"
    "    if (threadIdx.x == 0) {\n"
    "      auto* op = operands[wraps[0]];\n"
    "      op->indices[blockBase / kBlockSize] = nullptr;\n"
    "    }\n"
    "    __syncthreads();\n"
    "    return;\n"
    "  }\n"
    "\n"
    "  struct WrapState {\n"
    "    int32_t* indices;\n"
    "  };\n"
    "\n"
    "  auto* state = reinterpret_cast<WrapState*>(&shared->data);\n"
    "  bool rowActive = threadIdx.x < shared->numRows;\n"
    "  int32_t totalWrap = numWraps + shared->numExtraWraps;\n"
    "  for (auto column = 0; column < totalWrap; ++column) {\n"
    "    if (threadIdx.x == 0) {\n"
    "      auto opIndex = column < numWraps ? wraps[column]\n"
    "                                       : shared->extraWraps + column - numWraps;\n"
    "      auto* op = operands[opIndex];\n"
    "      int32_t** opIndices = &op->indices[blockBase / kBlockSize];\n"
    "      // If there is no indirection or if this is column 0 whose indirection is\n"
    "      // inited here, use the filter rows.\n"
    "      if (!*opIndices || column == 0) {\n"
    "        *opIndices = filterIndices + blockBase;\n"
    "        state->indices = nullptr;\n"
    "      } else {\n"
    "        state->indices = *opIndices;\n"
    "      }\n"
    "    }\n"
    "    __syncthreads();\n"
    "    // Every thread sees the decision on thred 0 above.\n"
    "    if (!state->indices) {\n"
    "      continue;\n"
    "    }\n"
    "    int32_t newIndex;\n"
    "    if (rowActive) {\n"
    "      newIndex =\n"
    "          state->indices[filterIndices[blockBase + threadIdx.x] - blockBase];\n"
    "    }\n"
    "    // All threads hit this.\n"
    "    __syncthreads();\n"
    "    if (rowActive) {\n"
    "      state->indices[threadIdx.x] = newIndex;\n"
    "    }\n"
    "  }\n"
    "  __syncthreads();\n"
    "}\n"
    "\n"
    "__device__ void __forceinline__ wrapKernel(\n"
    "    const OperandIndex* wraps,\n"
    "    int32_t numWraps,\n"
    "    OperandIndex indicesIdx,\n"
    "    Operand** operands,\n"
    "    Operand** newWraps,\n"
    "    Operand** backup,\n"
    "    int32_t blockBase,\n"
    "    WaveShared* shared) {\n"
    "  Operand* op = operands[indicesIdx];\n"
    "  auto* filterIndices = reinterpret_cast<int32_t*>(op->base);\n"
    "  if (filterIndices[blockBase + shared->numRows - 1] ==\n"
    "      shared->numRows + blockBase - 1) {\n"
    "    // There is no cardinality change.\n"
    "    if (threadIdx.x == 0) {\n"
    "      auto* op = operands[wraps[0]];\n"
    "      op->indices[blockBase / kBlockSize] = nullptr;\n"
    "    }\n"
    "    __syncthreads();\n"
    "    return;\n"
    "  }\n"
    "\n"
    "  struct WrapState {\n"
    "    int32_t* indices;\n"
    "  };\n"
    "\n"
    "  auto* state = reinterpret_cast<WrapState*>(&shared->data);\n"
    "  bool rowActive = threadIdx.x < shared->numRows;\n"
    "  int32_t totalWrap = numWraps + shared->numExtraWraps;\n"
    "  for (auto column = 0; column < totalWrap; ++column) {\n"
    "    if (threadIdx.x == 0) {\n"
    "      auto opIndex = column < numWraps ? wraps[column]\n"
    "                                       : shared->extraWraps + column - numWraps;\n"
    "      auto* op = operands[opIndex];\n"
    "      int32_t** opIndices = &op->indices[blockBase / kBlockSize];\n"
    "      // If there is no indirection or if this is column 0 whose indirection is\n"
    "      // inited here, use the filter rows.\n"
    "      if (!*opIndices || column == 0) {\n"
    "        *opIndices = filterIndices + blockBase;\n"
    "        state->indices = nullptr;\n"
    "      } else {\n"
    "        state->indices = *opIndices;\n"
    "      }\n"
    "    }\n"
    "    __syncthreads();\n"
    "    // Every thread sees the decision on thred 0 above.\n"
    "    if (!state->indices) {\n"
    "      continue;\n"
    "    }\n"
    "    int32_t newIndex;\n"
    "    if (rowActive) {\n"
    "      newIndex =\n"
    "          state->indices[filterIndices[blockBase + threadIdx.x] - blockBase];\n"
    "    }\n"
    "    // All threads hit this.\n"
    "    __syncthreads();\n"
    "    if (rowActive) {\n"
    "      state->indices[threadIdx.x] = newIndex;\n"
    "    }\n"
    "  }\n"
    "  __syncthreads();\n"
    "}\n"
    "\n"
    "__device__ void __forceinline__ wrapKernel(\n"
    "    OperandIndex first,\n"
    "    const OperandIndex* wraps,\n"
    "    const OperandIndex* newIndices,\n"
    "    const OperandIndex* backups,\n"
    "    int32_t numWraps,\n"
    "    OperandIndex indicesIdx,\n"
    "    WaveShared* shared) {\n"
    "  auto* operands = shared->operands;\n"
    "  Operand* op = operands[indicesIdx];\n"
    "  auto* filterIndices = reinterpret_cast<int32_t*>(op->base);\n"
    "\n"
    "  struct WrapState {\n"
    "    int32_t* indices;\n"
    "    int32_t* newIndices;\n"
    "  };\n"
    "\n"
    "  auto* state = reinterpret_cast<WrapState*>(&shared->data);\n"
    "  bool rowActive = threadIdx.x < shared->numRows;\n"
    "\n"
    "  if (first != kEmpty) {\n"
    "    if (threadIdx.x == 0) {\n"
    "      operands[first]->indices[shared->blockBase / kBlockSize] =\n"
    "          filterIndices + shared->blockBase;\n"
    "    }\n"
    "  }\n"
    "\n"
    "  for (auto column = 0; column < numWraps; ++column) {\n"
    "    if (threadIdx.x == 0) {\n"
    "      auto nthBlock = shared->blockBase / kBlockSize;\n"
    "      auto opIndex = wraps[column];\n"
    "      auto* op = operands[opIndex];\n"
    "      int32_t** opIndices = &op->indices[nthBlock];\n"
    "      // Record previous indirection\n"
    "      auto backup =\n"
    "          reinterpret_cast<int32_t**>(operands[backups[column]]->base);\n"
    "      backup[nthBlock] = *opIndices;\n"
    "      if (!*opIndices) {\n"
    "        *opIndices = filterIndices + shared->blockBase;\n"
    "        state->indices = nullptr;\n"
    "      } else {\n"
    "        state->indices = *opIndices;\n"
    "        state->newIndices =\n"
    "            reinterpret_cast<int32_t*>(operands[newIndices[column]]->base);\n"
    "      }\n"
    "    }\n"
    "    __syncthreads();\n"
    "    // Every thread sees the decision on thred 0 above.\n"
    "    if (!state->indices) {\n"
    "      continue;\n"
    "    }\n"
    "    int32_t newIndex;\n"
    "    if (rowActive) {\n"
    "      newIndex = state->indices\n"
    "                     [filterIndices[shared->blockBase + threadIdx.x] -\n"
    "                      shared->blockBase];\n"
    "      state->newIndices[threadIdx.x] = newIndex;\n"
    "    }\n"
    "  }\n"
    "  __syncthreads();\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "__device__ T value(Operand* operands, OperandIndex opIdx) {\n"
    "  // Obsolete signature. call sites must be changed.\n"
    "  //    assert(false);\n"
    "  *(long*)0 = 0;\n"
    "  return T{};\n"
    "}\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Hands the embedded WaveCore.cuh text to registerHeader() during static
// initialization of this translation unit; the returned flag initializes this
// namespace-scope variable.
bool velox_experimental_wave_exec_WaveCore_cuh_reg{
    registerHeader(velox_experimental_wave_exec_WaveCore_cuh)};
// Embedded text of velox/experimental/wave/exec/ExprKernel.h. The first line of
// the string is the header's path and the rest is the header source; the
// *_reg variable that follows passes it to registerHeader().
// NOTE(review): this looks like a generated copy of the real header (TODO
// confirm). Fixes belong in the source header, not only in this string.
// NOTE(review): KernelParams::operatorStates is the only pointer member of
// KernelParams without a {nullptr} default initializer, so a default-constructed
// KernelParams leaves it indeterminate. The embedded comments also contain
// typos ("rowm", "llength", "Subscripts is", "[numBlocks']").
const char* velox_experimental_wave_exec_ExprKernel_h =
    "velox/experimental/wave/exec/ExprKernel.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <stdint.h>\n"
    "#include \"velox/experimental/wave/common/HashTable.h\"\n"
    "#include \"velox/experimental/wave/exec/ErrorCode.h\"\n"
    "#include \"velox/experimental/wave/vector/Operand.h\"\n"
    "\n"
    "/// Wave common instruction set. Instructions run a thread block wide\n"
    "/// and offer common operations like arithmetic, conditionals,\n"
    "/// filters, hash lookups, etc. Several vectorized operators can fuse\n"
    "/// into one instruction stream. Instruction streams may use shared\n"
    "/// memory depending on the instruction mix. The shared memory is to\n"
    "/// be allocated dynamically at kernel invocation.\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "/// Device-side state for group by\n"
    "struct DeviceAggregation {\n"
    "  /// hash table, nullptr if no grouping keys.\n"
    "  GpuHashTableBase* table{nullptr};\n"
    "\n"
    "  /// Device side atomic counting thread blocks working on the state. Assert\n"
    "  /// this is 0 at rehash or resupply of allocators.\n"
    "  uint32_t debugActiveBlockCounter{0};\n"
    "\n"
    "  // Byte size of a rowm rounded to next 8.\n"
    "  int32_t rowSize = 0;\n"
    "\n"
    "  /// Allocator for variable llength accumulators, if not provided by 'table'.\n"
    "  RowAllocator* allocator{nullptr};\n"
    "\n"
    "  char* singleRow{nullptr};\n"
    "\n"
    "  /// Number of int64_t* in groupResultRows. One for each potential\n"
    "  /// streamIdx of reading kernel.\n"
    "  int32_t numReadStreams{0};\n"
    "\n"
    "  /// Pointers to group by result row arrays. Subscripts is\n"
    "  /// '[streamIdx][row + 1]'. Element 0 is the row count.\n"
    "  uintptr_t** resultRowPointers{nullptr};\n"
    "};\n"
    "\n"
    "/// Parameters for creating/updating a group by.\n"
    "struct AggregationControl {\n"
    "  /// Pointer to page-aligned DeviceAggregation.\n"
    "  DeviceAggregation* head;\n"
    "\n"
    "  /// Size of block starting at 'head'. Must be set on first setup.\n"
    "  int64_t headSize{0};\n"
    "\n"
    "  /// For a rehashing command, old bucket array.\n"
    "  void* oldBuckets{nullptr};\n"
    "\n"
    "  /// Count of buckets starting at oldBuckets for rehash.\n"
    "  int64_t numOldBuckets{0};\n"
    "\n"
    "  /// Size of single row allocation.\n"
    "  int32_t rowSize{0};\n"
    "};\n"
    "\n"
    "struct AggregateReturn {\n"
    "  /// Count of rows in the table. Triggers rehash when high enough.\n"
    "  int64_t numDistinct;\n"
    "};\n"
    "\n"
    "// Return status for hash join build. Only indicates if more space is needed.\n"
    "struct BuildReturn {\n"
    "  // Flag 8 wide for alignment.\n"
    "  int64_t needMore{0};\n"
    "};\n"
    "\n"
    "/// Thread block wide status in Wave kernels\n"
    "struct WaveShared {\n"
    "  /// per lane status and row count.\n"
    "  BlockStatus* status;\n"
    "  Operand** operands;\n"
    "  void** states;\n"
    "\n"
    "  /// Every wrap in the kernel will also wrap these otherwise not accessed\n"
    "  /// Operands.\n"
    "  OperandIndex extraWraps;\n"
    "  int16_t numExtraWraps;\n"
    "  // The continue label where execution is to resume if continuing.\n"
    "  int16_t startLabel;\n"
    "  /// True if continuing the first instruction. The instruction will\n"
    "  /// pick up its lane status from blockStatus or an\n"
    "  /// instruction-specific source. The instruction must clear this\n"
    "  /// before executing the next instruction.\n"
    "  bool isContinue;\n"
    "\n"
    "  /// True if some lane needs a continue. Used inside a kernel to\n"
    "  /// indicate that the grid level status should be set to indicate\n"
    "  /// continue. Reset before end of instruction.\n"
    "  bool hasContinue;\n"
    "\n"
    "  /// True if doing an extra iteration to increase cardinality in\n"
    "  /// non-continue case, e.g. first kernel of 1:n join or unnest\n"
    "  /// producing non-first result for an active lane of input.\n"
    "  bool localContinue;\n"
    "\n"
    "  /// If true, all threads in block return before starting next instruction.\n"
    "  bool stop;\n"
    "  int32_t blockBase;\n"
    "  int32_t numRows;\n"
    "  /// Number of blocks for the program. Return statuses are at\n"
    "  /// '&blockStatus[numBlocks']\n"
    "  int32_t numBlocks;\n"
    "\n"
    "  // The branch of a multibranch kernel this block is doing.\n"
    "  int16_t programIdx;\n"
    "\n"
    "  /// Number of items in blockStatus covered by each TB.\n"
    "  int16_t numRowsPerThread;\n"
    "\n"
    "  /// Iteration counter, =0; < numRowsPerThread.\n"
    "  int16_t nthBlock;\n"
    "  int16_t streamIdx;\n"
    "\n"
    "  // Scratch data area. Size depends on shared memory size for instructions.\n"
    "  // Align 8.\n"
    "  int64_t data;\n"
    "};\n"
    "\n"
    "/// Parameters for a Wave kernel. All pointers are device side readable.\n"
    "struct KernelParams {\n"
    "  /// The first thread block with the program. Subscript is blockIdx.x.\n"
    "  int32_t* blockBase{nullptr};\n"
    "  // The ordinal of the program. All blocks with the same program have the same\n"
    "  // number here. Subscript is blockIdx.x. For compiled kernels, this gives the\n"
    "  // branch to follow for the TB at blockIdx.x.\n"
    "  int32_t* programIdx{nullptr};\n"
    "\n"
    "  /// The label where to start the execution. If nullptr,\n"
    "  /// 0. Otherwise subscript is programIdx. The active lanes are given\n"
    "  /// in 'blockStatus'. Used when restarting program at a specific\n"
    "  /// instruction, e.g. after allocating new memory on host. nullptr means first\n"
    "  /// launch, starting at 0.\n"
    "  int32_t* startPC{nullptr};\n"
    "\n"
    "  // For each exe, the start of the array of Operand*. Instructions reference\n"
    "  // operands via offset in this array. The subscript is\n"
    "  // programIdx[blockIdx.x].\n"
    "  Operand*** operands{nullptr};\n"
    "\n"
    "  // the status return block for each TB. The subscript is blockIdx.x -\n"
    "  // (blockBase[blockIdx.x] / kBlockSize). Shared between all programs.\n"
    "  BlockStatus* status{nullptr};\n"
    "\n"
    "  // Address of global states like hash tables. Subscript is 'programIdx' and\n"
    "  // next subscript is state id in the instruction.\n"
    "  void*** operatorStates;\n"
    "\n"
    "  /// first operand index for extra wraps. 'numExtraWraps' next\n"
    "  /// operands get wrapped by all wraps in the kernel.\n"
    "  OperandIndex extraWraps{0};\n"
    "  int16_t numExtraWraps{0};\n"
    "\n"
    "  /// Number of blocks in each program. gridDim.x can be a multiple if many\n"
    "  /// programs in launch.\n"
    "  int32_t numBlocks{0};\n"
    "\n"
    "  /// Number of elements of blockStatus covered by each TB.\n"
    "  int16_t numRowsPerThread{1};\n"
    "\n"
    "  /// Id of stream <stream ordinal within WaveDriver> + (<driverId of\n"
    "  /// WaveDriver> * <number of Drivers>.\n"
    "  int16_t streamIdx{0};\n"
    "};\n"
    "\n"
    "/// grid status for a hash join expand.\n"
    "struct HashJoinExpandGridStatus {\n"
    "  /// True if any lane in the grid has unfetched data. If true,\n"
    "  /// hashJoinExpandBlockstatus has the lane statuses. int32_t to support\n"
    "  /// atomics.\n"
    "  int32_t anyContinuable{0};\n"
    "  int32_t pad{0};\n"
    "};\n"
    "\n"
    "/// Tracks the state of a hash join probe between batches of output. Needed when\n"
    "/// the join increases cardinality.\n"
    "struct HashJoinExpandBlockStatus {\n"
    "  /// The next row in the hash table to look at. nullptr if all hits produced.\n"
    "  void* next[kBlockSize];\n"
    "};\n"
    "\n"
    "// Shared memory resident state for hash join expand.\n"
    "struct JoinShared {\n"
    "  HashJoinExpandGridStatus* gridStatus;\n"
    "  HashJoinExpandBlockStatus* blockStatus;\n"
    "  int32_t anyNext;\n"
    "  int32_t temp[kBlockSize / 32];\n"
    "};\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Hands the embedded ExprKernel.h text to registerHeader() during static
// initialization of this translation unit; the returned flag initializes this
// namespace-scope variable.
bool velox_experimental_wave_exec_ExprKernel_h_reg{
    registerHeader(velox_experimental_wave_exec_ExprKernel_h)};
// Embedded text of velox/experimental/wave/exec/Accumulators.cuh. The first
// line of the string is the header's path and the rest is the header source;
// the *_reg variable that follows passes it to registerHeader().
//
// Fix: sumReduce() loaded the per-warp partial sums with the guard
// 'threadIdx.x < kWarpThreads'. That guard sits inside an enclosing
// 'if (threadIdx.x < kWarpThreads)', so it was always true, and lanes
// >= kNumWarps read past the kNumWarps-entry warpSum array (into the warpAny
// flags stored right after it, and beyond). The guard is now
// 'threadIdx.x < kNumWarps', so those lanes contribute 0 to the final sum.
// NOTE(review): Reduce8 (WarpReduce<T, 8>) presumably reduces across 8 lanes.
// That covers every warp only when kBlockSize / kWarpThreads <= 8; confirm for
// other block sizes.
// NOTE(review): for T = unsigned int, a call atomicInc(ptr, inc) resolves to
// CUDA's built-in wrapping atomicInc(unsigned int*, unsigned int) rather than
// the template defined here. Confirm no accumulator uses that type.
// This text appears to be generated from the source header (TODO confirm);
// apply the same fix there.
const char* velox_experimental_wave_exec_Accumulators_cuh =
    "velox/experimental/wave/exec/Accumulators.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <cuda/atomic>\n"
    "#include \"velox/experimental/wave/common/Scan.cuh\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "template <typename T>\n"
    "__device__ inline void atomicInc(T* ptr, T inc) {\n"
    "  atomicAdd(ptr, inc);\n"
    "}\n"
    "\n"
    "template <>\n"
    "__device__ inline void atomicInc(int64_t* ptr, int64_t inc) {\n"
    "  atomicAdd((unsigned long long*)ptr, (unsigned long long)inc);\n"
    "}\n"
    "\n"
    "template <typename AccType, typename IncType, typename Reduce2>\n"
    "void __device__ __forceinline__ simpleAccumulate(\n"
    "    uint32_t peers,\n"
    "    int32_t leader,\n"
    "    int32_t lane,\n"
    "    AccType* acc,\n"
    "    uint32_t nulls,\n"
    "    uint32_t* aggNulls,\n"
    "    uint32_t aggNullMask,\n"
    "    IncType input,\n"
    "    bool inputIsNull,\n"
    "    Reduce2 reduce) {\n"
    "  IncType agg;\n"
    "  auto toUpdate = peers;\n"
    "  bool isAny = false;\n"
    "  for (;;) {\n"
    "    auto peer = __ffs(toUpdate) - 1;\n"
    "    auto inc = __shfl_sync(peers, input, peer);\n"
    "    auto isNull = __shfl_sync(peers, inputIsNull, peer);\n"
    "    if (lane == leader) {\n"
    "      if (!isNull) {\n"
    "        agg = !isAny ? inc : reduce(agg, inc);\n"
    "        isAny = true;\n"
    "      }\n"
    "      if (peer == leader) {\n"
    "        if (isAny) {\n"
    "          atomicInc(acc, agg);\n"
    "          if ((nulls & aggNullMask) == 0) {\n"
    "            atomicOr(aggNulls, aggNullMask);\n"
    "            nulls |= aggNullMask;\n"
    "          }\n"
    "        }\n"
    "        break;\n"
    "      }\n"
    "    }\n"
    "    if ((toUpdate &= toUpdate - 1) == 0) {\n"
    "      return;\n"
    "    }\n"
    "  }\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "__device__ __forceinline__ void sumReduce(\n"
    "    T input,\n"
    "    bool isNull,\n"
    "    ErrorCode laneStatus,\n"
    "    uint32_t nulls,\n"
    "    T* result,\n"
    "    uint32_t* resultNulls,\n"
    "    uint32_t nullMask,\n"
    "    void* smem) {\n"
    "  using Reduce = WarpReduce<T>;\n"
    "  using Reduce8 = WarpReduce<T, 8>;\n"
    "  constexpr int32_t kNumWarps = kBlockSize / kWarpThreads;\n"
    "  T* warpSum = reinterpret_cast<T*>(smem);\n"
    "  bool* warpAny = reinterpret_cast<bool*>(warpSum + kNumWarps);\n"
    "  bool nonNull = laneStatus == ErrorCode::kOk && !isNull;\n"
    "  bool warpFlag = __ballot_sync(0xffffffff, nonNull) != 0;\n"
    "  T laneValue = nonNull ? input : 0;\n"
    "  T warpResult = Reduce().reduce(laneValue, [](T x, T y) { return x + y; });\n"
    "  if ((threadIdx.x & 31) == 0) {\n"
    "    warpAny[threadIdx.x / kWarpThreads] = warpFlag;\n"
    "    warpSum[threadIdx.x / kWarpThreads] = warpResult;\n"
    "  }\n"
    "  __syncthreads();\n"
    "  bool anyAtAll;\n"
    "  if (threadIdx.x < kNumWarps) {\n"
    "    anyAtAll =\n"
    "        __ballot_sync(lowMask<uint32_t>(kNumWarps), warpAny[threadIdx.x]) != 0;\n"
    "  }\n"
    "  T finalSum;\n"
    "  if (threadIdx.x < kWarpThreads) {\n"
    "    finalSum = Reduce8().reduce(\n"
    "        threadIdx.x < kNumWarps ? warpSum[threadIdx.x] : 0,\n"
    "        [](T x, T y) { return x + y; });\n"
    "    if (threadIdx.x == 0) {\n"
    "      if (anyAtAll) {\n"
    "        atomicInc(result, finalSum);\n"
    "        if ((nulls & nullMask) == 0) {\n"
    "          atomicOr(resultNulls, nullMask);\n"
    "        }\n"
    "      }\n"
    "    }\n"
    "  }\n"
    "}\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Hands the embedded Accumulators.cuh text to registerHeader() during static
// initialization of this translation unit; the returned flag initializes this
// namespace-scope variable.
bool velox_experimental_wave_exec_Accumulators_cuh_reg{
    registerHeader(velox_experimental_wave_exec_Accumulators_cuh)};
// Embedded text of velox/experimental/wave/exec/Join.cuh. The first line of the
// string is the header's path and the rest is the header source; the *_reg
// variable that follows passes it to registerHeader().
// Contents:
// - joinShared() views WaveShared::data as a JoinShared.
// - loadJoinNext() zeroes shared->numRows (thread 0), reloads each lane's
//   continuation pointer from the block status, and marks lanes whose pointer
//   is null kInactive.
// - joinResult() writes matching rows to the hits/indices operands. On the
//   first pass it compacts them with exclusiveSum; on continuation it appends
//   them with atomicAdd on shared->numRows, capped at kBlockSize.
// - joinRow() calls 'copy' for each lane whose status is kOk.
// NOTE(review): the localContinue threshold is kBlockSize - 64 on the first
// pass but kBlockSize - 32 on continuation. Confirm this is intentional; any
// fix belongs in the source header this text appears to be generated from.
const char* velox_experimental_wave_exec_Join_cuh =
    "velox/experimental/wave/exec/Join.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include \"velox/experimental/wave/common/HashTable.cuh\"\n"
    "#include \"velox/experimental/wave/exec/WaveCore.cuh\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "inline __device__ JoinShared* joinShared(WaveShared* shared) {\n"
    "  return reinterpret_cast<JoinShared*>(&shared->data);\n"
    "}\n"
    "\n"
    "template <int32_t gridStatusSize, int32_t blockStatusOffset>\n"
    "int64_t __device__ loadJoinNext(WaveShared* shared, ErrorCode& laneStatus) {\n"
    "  auto* status = blockStatus<HashJoinExpandBlockStatus>(\n"
    "      shared, gridStatusSize, blockStatusOffset);\n"
    "  if (threadIdx.x == 0) {\n"
    "    shared->numRows = 0;\n"
    "  }\n"
    "  auto h = reinterpret_cast<int64_t>(status->next[threadIdx.x]);\n"
    "  laneStatus = h != 0 ? ErrorCode::kOk : ErrorCode::kInactive;\n"
    "  return h;\n"
    "}\n"
    "\n"
    "template <\n"
    "    typename RowType,\n"
    "    int32_t indicesIdx,\n"
    "    int32_t gridstatusOffset,\n"
    "    int32_t gridStatusSize,\n"
    "    int32_t blockStatusOffset,\n"
    "    int32_t hitsIdx>\n"
    "bool __device__ __forceinline__ joinResult(\n"
    "    int64_t& hitAsInt,\n"
    "    bool filterResult,\n"
    "    bool joinContinue,\n"
    "    ErrorCode laneStatus,\n"
    "    WaveShared* shared,\n"
    "    bool hasDuplicates) {\n"
    "  RowType* hit =\n"
    "      reinterpret_cast<RowType*>(laneStatus == ErrorCode::kOk ? hitAsInt : 0);\n"
    "  if (threadIdx.x == 0) {\n"
    "    auto* j = joinShared(shared);\n"
    "    if (hasDuplicates) {\n"
    "      j->gridStatus =\n"
    "          gridStatus<HashJoinExpandGridStatus>(shared, gridstatusOffset);\n"
    "      j->blockStatus = blockStatus<HashJoinExpandBlockStatus>(\n"
    "          shared, gridStatusSize, blockStatusOffset);\n"
    "      j->anyNext = 0;\n"
    "    } else {\n"
    "      j->gridStatus = nullptr;\n"
    "      j->blockStatus = nullptr;\n"
    "    }\n"
    "  }\n"
    "  if (!joinContinue) {\n"
    "    auto nth = exclusiveSum<int32_t, kBlockSize>(\n"
    "        (hit ? filterResult : 0), &shared->numRows, joinShared(shared)->temp);\n"
    "    if (hit && filterResult) {\n"
    "      auto* hits = reinterpret_cast<RowType**>(shared->operands[hitsIdx]->base);\n"
    "      hits[shared->blockBase + nth] = hit;\n"
    "      auto* indices =\n"
    "          reinterpret_cast<int32_t*>(shared->operands[indicesIdx]->base);\n"
    "      indices[shared->blockBase + nth] = shared->blockBase + threadIdx.x;\n"
    "    }\n"
    "    __syncthreads();\n"
    "    if (threadIdx.x == kBlockSize - 1) {\n"
    "      shared->numRows = nth + (hit && filterResult);\n"
    "    }\n"
    "\n"
    "    if (!hasDuplicates) {\n"
    "      // syncthreads in caller.\n"
    "      return false;\n"
    "    }\n"
    "    RowType* next = nullptr;\n"
    "    if (hit) {\n"
    "      next = *hit->nextPtr();\n"
    "      hit = next;\n"
    "      hitAsInt = reinterpret_cast<int64_t>(hit);\n"
    "    }\n"
    "    joinShared(shared)->blockStatus->next[threadIdx.x] = next;\n"
    "    uint32_t flags = __ballot_sync(0xffffffff, next != nullptr);\n"
    "    if ((threadIdx.x & 31) == 0) {\n"
    "      if (flags) {\n"
    "        atomicOr(&joinShared(shared)->anyNext, flags);\n"
    "        joinShared(shared)->gridStatus->anyContinuable = true;\n"
    "      }\n"
    "    }\n"
    "    __syncthreads();\n"
    "    if (threadIdx.x == 0) {\n"
    "      shared->localContinue =\n"
    "          shared->numRows < kBlockSize - 64 && joinShared(shared)->anyNext;\n"
    "    }\n"
    "    __syncthreads();\n"
    "    // All threads return the same. true if there is space in the output and\n"
    "    // nexts to look at.\n"
    "    return shared->localContinue;\n"
    "  }\n"
    "  // We come here when there are  places to fill above shared->numRows.\n"
    "  bool laneFull = false;\n"
    "  if (hit && filterResult) {\n"
    "    auto row = atomicAdd(&shared->numRows, 1);\n"
    "    if (row < kBlockSize) {\n"
    "      auto* hits = reinterpret_cast<RowType**>(shared->operands[hitsIdx]->base);\n"
    "      auto* indices =\n"
    "          reinterpret_cast<int32_t*>(shared->operands[indicesIdx]->base);\n"
    "      indices[shared->blockBase + row] = shared->blockBase + threadIdx.x;\n"
    "      hits[shared->blockBase + row] = hit;\n"
    "    } else {\n"
    "      laneFull = true;\n"
    "    }\n"
    "  }\n"
    "  // Make sure joinShared is seen on all threads.\n"
    "  __syncthreads();\n"
    "  if (!laneFull && hit) {\n"
    "    auto* next = *hit->nextPtr();\n"
    "    joinShared(shared)->blockStatus->next[threadIdx.x] = next;\n"
    "    hitAsInt = reinterpret_cast<int64_t>(next);\n"
    "  }\n"
    "\n"
    "  if (threadIdx.x == 0 && shared->numRows > kBlockSize) {\n"
    "    shared->numRows = kBlockSize;\n"
    "  }\n"
    "  uint32_t flags = __ballot_sync(\n"
    "      0xffffffff,\n"
    "      joinShared(shared)->blockStatus->next[threadIdx.x] != nullptr);\n"
    "  if ((threadIdx.x & 31) == 0) {\n"
    "    if (flags) {\n"
    "      atomicOr(&joinShared(shared)->anyNext, flags);\n"
    "    }\n"
    "  }\n"
    "  __syncthreads();\n"
    "  if (threadIdx.x == 0) {\n"
    "    if (joinShared(shared)->anyNext) {\n"
    "      joinShared(shared)->gridStatus->anyContinuable = 1;\n"
    "    }\n"
    "    shared->localContinue =\n"
    "        shared->numRows < kBlockSize - 32 && joinShared(shared)->anyNext;\n"
    "  }\n"
    "  __syncthreads();\n"
    "  return shared->localContinue;\n"
    "}\n"
    "\n"
    "template <typename RowType, int32_t hitsIdx, typename CopyRow>\n"
    "void __device__ __forceinline__\n"
    "joinRow(ErrorCode laneStatus, WaveShared* shared, CopyRow copy) {\n"
    "  if (laneStatus == ErrorCode::kOk) {\n"
    "    RowType** hits =\n"
    "        reinterpret_cast<RowType**>(shared->operands[hitsIdx]->base);\n"
    "    copy(\n"
    "        hits[shared->blockBase + threadIdx.x], shared->blockBase + threadIdx.x);\n"
    "  }\n"
    "}\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Passes the embedded Join.cuh text to registerHeader() while this
// namespace-scope variable is dynamically initialized; the result is kept here.
bool velox_experimental_wave_exec_Join_cuh_reg{
    registerHeader(velox_experimental_wave_exec_Join_cuh)};
// Text of velox/experimental/wave/common/HashTable.h embedded as a C string:
// the first line is the header path, the rest is the header source. It is
// handed to registerHeader() by velox_experimental_wave_common_HashTable_h_reg,
// presumably so runtime (NVRTC) compilation can resolve #includes of this path
// (the text has __CUDACC_RTC__ guards). The bytes are runtime data and this
// copy appears to be generated from the source header, so the issues noted
// below should be fixed there and the text regenerated -- TODO confirm.
const char* velox_experimental_wave_common_HashTable_h =
    "velox/experimental/wave/common/HashTable.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#ifndef __CUDACC_RTC__\n"
    "#include <string.h>\n"
    "#include <string>\n"
    "#endif\n"
    "#include <stdint.h>\n"
    "\n"
    "/// Structs for tagged GPU hash table. Can be inclued in both Velox .cpp and\n"
    "/// .cu.\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "/// A 32 byte tagged bucket with 4 tags, 4 flag bytes and 4 6-byte\n"
    "/// pointers. Fits in one 32 byte GPU cache sector.\n"
    "struct GpuBucketMembers {\n"
    "  static constexpr int32_t kNumSlots = 4;\n"
    "\n"
    "  uint32_t tags;\n"
    "  uint32_t flags;\n"
    "  uint16_t data[12];\n"
    "\n"
    "#ifndef __CUDACC_RTC__\n"
    "  template <typename T>\n"
    "  T* testingLoad(int32_t idx) {\n"
    "    auto uptr = static_cast<uint64_t>(data[8 + idx]) << 32;\n"
    "    uptr |= reinterpret_cast<uint32_t*>(data)[idx];\n"
    "    return reinterpret_cast<T*>(uptr);\n"
    "  }\n"
    "#endif\n"
    "};\n"
    "\n"
    "template <typename T, int32_t kSize>\n"
    "class FreeSetBase {\n"
    "  int32_t full_{0};\n"
    "  int32_t empty_{1};\n"
    "  unsigned long long bits_[kSize / 64] = {};\n"
    "  T items_[kSize] = {};\n"
    "};\n"
    "\n"
    "#ifndef __CUDACC_RTC__\n"
    "static inline int32_t roundUp64(int32_t value) {\n"
    "  return (value + 64 - 1) / 64 * 64;\n"
    "}\n"
    "#endif\n"
    "\n"
    "/// Range of addresses. fixed length from bottom and variable length from top.\n"
    "/// if 'rowOffset' goes above 'rowLimit' then rows are full. If 'stringOffset'\n"
    "/// goes below 'rowLimit' then strings are full.\n"
    "struct AllocationRange {\n"
    "  AllocationRange() = default;\n"
    "\n"
    "#ifndef __CUDACC_RTC__\n"
    "  AllocationRange(\n"
    "      uintptr_t base,\n"
    "      uint32_t capacity,\n"
    "      uint32_t rowLimit,\n"
    "      int32_t rowSize)\n"
    "      : fixedFull(false),\n"
    "        variableFull(false),\n"
    "        base(base),\n"
    "        capacity(capacity),\n"
    "        rowLimit(rowLimit),\n"
    "        // We leave n words of 64 bits, one bit for each possible row within\n"
    "        // 'capacity' below first row.\n"
    "        firstRowOffset(roundUp64(capacity / rowSize) / 8),\n"
    "        rowOffset(firstRowOffset),\n"
    "        stringOffset(capacity) {\n"
    "    ::memset(reinterpret_cast<char*>(base), 0, firstRowOffset);\n"
    "  }\n"
    "\n"
    // NOTE(review): the embedded move constructor and raiseRowLimit use
    // std::move / std::min, but the text only includes <string.h>, <string>
    // and <stdint.h>. Presumably <utility>/<algorithm> arrive transitively via
    // <string> -- TODO confirm, or include them explicitly.
    "  AllocationRange(AllocationRange&& other) {\n"
    "    *this = std::move(other);\n"
    "  }\n"
    "\n"
    "  AllocationRange& operator=(const AllocationRange& other) = default;\n"
    "\n"
    "  void operator=(AllocationRange&& other) {\n"
    "    *this = other;\n"
    "    memset(&other, 0, sizeof(AllocationRange));\n"
    "  }\n"
    "\n"
    "  int64_t availableFixed() {\n"
    "    return rowOffset > rowLimit ? 0 : rowLimit - rowOffset;\n"
    "  }\n"
    "\n"
    // NOTE(review): in raiseRowLimit, 'stringOffset - rowLimit' is uint32_t
    // arithmetic. If stringOffset < rowLimit (e.g. after failed string
    // allocations and before clearOverflows) it wraps, and the
    // std::min<int32_t> conversion can then yield a negative delta that
    // *lowers* rowLimit.
    "  /// Raises rowLimit by up to 'size'. Returns the amount raised.\n"
    "  int32_t raiseRowLimit(int32_t size) {\n"
    "    auto space = stringOffset - rowLimit;\n"
    "    auto delta = std::min<int32_t>(space, size);\n"
    "    rowLimit += delta;\n"
    "    if (delta > 0) {\n"
    "      fixedFull = false;\n"
    "    }\n"
    "    return delta;\n"
    "  }\n"
    "\n"
    "  void clearOverflows(int32_t rowSize) {\n"
    "    if (rowOffset > rowLimit) {\n"
    "      // Set 'rowOffset' to the greatest multipl of rowSize from 'base' that is\n"
    "      // below the limit.\n"
    "      int32_t numRows = (rowLimit - firstRowOffset) / rowSize;\n"
    "      rowOffset = firstRowOffset + numRows * rowSize;\n"
    "    }\n"
    "    if (stringOffset < rowLimit) {\n"
    "      stringOffset = rowLimit;\n"
    "    }\n"
    "  }\n"
    "\n"
    // NOTE(review): in trimFixed, 'rowLimit - rowOffset' is unsigned. When
    // rowOffset > rowLimit (the over-limit state clearOverflows repairs) it
    // wraps, so 'available > target' holds and rowLimit is set to
    // rowOffset + target, which *raises* the limit. Callers presumably run
    // clearOverflows first -- TODO confirm.
    "  /// Sets row limit so that there are at most 'target' allocatable\n"
    "  /// bytes. If available space is less than the target, the available\n"
    "  /// space is not changed. Returns 'target' minus the available space\n"
    "  // in 'this'.\n"
    "  int32_t trimFixed(int32_t target) {\n"
    "    auto available = rowLimit - rowOffset;\n"
    "    if (available > target) {\n"
    "      rowLimit = rowOffset + target;\n"
    "    }\n"
    "    return target - (rowLimit - rowOffset);\n"
    "  }\n"
    "\n"
    "  /// True if in post-default constructed state.\n"
    "  bool empty() {\n"
    "    return capacity == 0;\n"
    "  }\n"
    "\n"
    "  std::string toString(int32_t rowSize);\n"
    "#endif\n"
    "  bool fixedFull{true};\n"
    "  bool variableFull{true};\n"
    "  /// Number of the partition. Used when filing away ranges on the control\n"
    "  /// plane.\n"
    "  uint8_t partition{0};\n"
    "  uint64_t base{0};\n"
    "  uint32_t capacity{0};\n"
    "  uint32_t rowLimit{0};\n"
    "  int32_t firstRowOffset{0};\n"
    "  uint32_t rowOffset{0};\n"
    "  uint32_t stringOffset{0};\n"
    "};\n"
    "\n"
    "/// A device arena for device side allocation.\n"
    "struct HashPartitionAllocator {\n"
    "  static constexpr uint32_t kEmpty = ~0;\n"
    "\n"
    "#ifndef __CUDACC_RTC__\n"
    "  HashPartitionAllocator(\n"
    "      char* data,\n"
    "      uint32_t capacity,\n"
    "      uint32_t rowLimit,\n"
    "      uint32_t rowSize)\n"
    "      : rowSize(rowSize) {\n"
    "    ranges[0] = AllocationRange(\n"
    "        reinterpret_cast<uintptr_t>(data), capacity, rowLimit, rowSize);\n"
    "  }\n"
    "  /// Returns the available bytes  in fixed size pools.\n"
    "  int64_t availableFixed() {\n"
    "    return ranges[0].availableFixed() + ranges[1].availableFixed();\n"
    "  }\n"
    "\n"
    "  /// Sets allocated offsets to limit if these are over the\n"
    "  /// limit. They are over limit and available is negative after many\n"
    "  /// concurrent failed allocations.\n"
    "  void clearOverflows() {\n"
    "    ranges[0].clearOverflows(rowSize);\n"
    "    ranges[1].clearOverflows(rowSize);\n"
    "  }\n"
    "\n"
    "  /// Raises the row limit by up to size bytes. Returns th amount raised.\n"
    "  int32_t raiseRowLimits(int32_t size) {\n"
    "    auto raised = ranges[0].raiseRowLimit(size);\n"
    "    return raised + ranges[1].raiseRowLimit(size - raised);\n"
    "  }\n"
    "\n"
    "  /// sets rowLimit so that there will be at most 'maxSize' bytes of fixed\n"
    "  /// length.\n"
    "  void trimRows(int32_t target) {\n"
    "    target = ranges[0].trimFixed(target);\n"
    "    ranges[1].trimFixed(target);\n"
    "  }\n"
    "\n"
    "  std::string toString();\n"
    "#endif\n"
    "\n"
    "  const int32_t rowSize{0};\n"
    "  AllocationRange ranges[2];\n"
    "};\n"
    "\n"
    "/// Implementation of HashPartitionAllocator, defined in .cuh.\n"
    "struct RowAllocator;\n"
    "\n"
    "enum class ProbeState : uint8_t { kDone, kMoreValues, kNeedSpace, kRetry };\n"
    "\n"
    "/// Operands for one TB of hash probe.\n"
    "struct HashProbe {\n"
    "  /// The number of input rows processed by each thread of a TB. The base index\n"
    "  /// for a block in the arrays in 'this' is 'numRowsPerThread * blockDim.x *\n"
    "  /// blockIdx.x'\n"
    "  int32_t numRowsPerThread{1};\n"
    "\n"
    "  /// Count of probe keys for each TB. Subscript is blockIdx.x.\n"
    "  int32_t* numRows;\n"
    "\n"
    "  /// Data for probe keys. To be interpreted by Ops of the probe, no\n"
    "  /// fixed format.\n"
    "  void* keys;\n"
    "\n"
    "  /// Hash numbers for probe keys.\n"
    "  uint64_t* hashes;\n"
    "\n"
    "  /// List of input rows to retry in kernel. Sized to one per row of\n"
    "  /// input. Used inside kernel, not meaningful after return. Sample\n"
    "  /// use case is another warp updating the same row.\n"
    "  int32_t* kernelRetries1;\n"
    "  int32_t* kernelRetries2;\n"
    "\n"
    "  /// List of input rows to retry after host updated state. Sized to\n"
    "  /// one per row of input. The reason for a host side retry is\n"
    "  /// needing more space. The host will decide to allocate/spill/error\n"
    "  /// out.\n"
    "  int32_t* hostRetries;\n"
    "\n"
    "  /// Count of valid items in 'hostRetries'. The subscript is blockIdx.x.\n"
    "  int32_t* numHostRetries;\n"
    "\n"
    "  /// Space in 'hits' and 'hitRows'. Should be a multiple of probe block width.\n"
    "  int32_t maxHits{0};\n"
    "\n"
    "  /// Row numbers for hits. Indices into 'hashes'.\n"
    "  int32_t* hitRows{nullptr};\n"
    "\n"
    "  // Optional payload rows hitting from a probe.\n"
    "  void** hits{nullptr};\n"
    "};\n"
    "\n"
    "struct GpuBucket;\n"
    "\n"
    // NOTE(review): in the embedded GpuHashTableBase constructor,
    // '(sizeMask + 1) * GpuBucketMembers::kNumSlots' is int32_t arithmetic and
    // overflows once sizeMask + 1 reaches 2^29 buckets.
    "struct GpuHashTableBase {\n"
    "#ifndef __CUDACC_RTC__\n"
    "  GpuHashTableBase(\n"
    "      GpuBucket* buckets,\n"
    "      int32_t sizeMask,\n"
    "      int32_t partitionMask,\n"
    "      RowAllocator* allocators)\n"
    "      : buckets(buckets),\n"
    "        sizeMask(sizeMask),\n"
    "        partitionMask(partitionMask),\n"
    "        allocators(allocators),\n"
    "        maxEntries(((sizeMask + 1) * GpuBucketMembers::kNumSlots) / 6 * 5) {}\n"
    "#endif\n"
    "  /// Bucket array. Size is 'sizeMask + 1'.\n"
    "  GpuBucket* buckets{nullptr};\n"
    "\n"
    "  // Mask to extract index into 'buckets' from a hash number. a\n"
    "  // sizemask of 63 means 64 buckets, which is up to 256 entries.\n"
    "  uint32_t sizeMask;\n"
    "\n"
    "  // Translates a hash number to a partition number '(hash >> 41) &\n"
    "  // partitionMask gives a physical partition of the table. Used as\n"
    "  // index into 'allocators'.\n"
    "  uint32_t partitionMask{0};\n"
    "\n"
    "  /// true if this is a join table where duplicates exist (at least one\n"
    "  /// next link is non-nullptr). int32_t to allow atomic ops.\n"
    "  int32_t hasDuplicates{0};\n"
    "\n"
    "  /// A RowAllocator for each partition.\n"
    "  RowAllocator* allocators;\n"
    "\n"
    "  /// Count of entries in buckets.\n"
    "  int64_t numDistinct{0};\n"
    "\n"
    "  /// Maximum number of entries. Incremented by atomic add at warp\n"
    "  /// level. Must be at least 32 belo count of slots. If numDistinct\n"
    "  /// after add exceeds max, the inserts in the warp fail and will be\n"
    "  /// retried after rehash.\n"
    "  int64_t maxEntries{0};\n"
    "};\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Passes the embedded HashTable.h text to registerHeader() while this
// namespace-scope variable is dynamically initialized; the result is kept here.
bool velox_experimental_wave_common_HashTable_h_reg{
    registerHeader(velox_experimental_wave_common_HashTable_h)};
// Text of velox/experimental/wave/common/HashTable.cuh embedded as a C string:
// the first line is the header path, the rest is the header source (device-side
// RowAllocator, GpuBucket and GpuHashTable). It is handed to registerHeader()
// by velox_experimental_wave_common_HashTable_cuh_reg. The bytes are runtime
// data and this copy appears to be generated from the source header, so the
// issues noted below should be fixed there and the text regenerated -- TODO
// confirm the generator.
const char* velox_experimental_wave_common_HashTable_cuh =
    "velox/experimental/wave/common/HashTable.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <cuda/atomic>\n"
    "#include <cuda/semaphore>\n"
    "#include \"velox/experimental/wave/common/BitUtil.cuh\"\n"
    "#include \"velox/experimental/wave/common/Hash.h\"\n"
    "#include \"velox/experimental/wave/common/HashTable.h\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    // NOTE: GPF() stores through a null pointer, i.e. it forces a memory fault
    // (presumably a debugging aid).
    "#define GPF() *(long*)0 = 0\n"
    "\n"
    "template <typename T, typename U>\n"
    "inline __device__ cuda::atomic<T, cuda::thread_scope_device>* asDeviceAtomic(\n"
    "    U* ptr) {\n"
    "  return reinterpret_cast<cuda::atomic<T, cuda::thread_scope_device>*>(ptr);\n"
    "}\n"
    "\n"
    // NOTE(review): atomicTryLock exchanges with memory_order_consume;
    // memory_order_acquire is the conventional ordering for taking a lock
    // (implementations typically promote consume to acquire anyway).
    "template <typename T>\n"
    "inline bool __device__ atomicTryLock(T* lock) {\n"
    "  return 0 ==\n"
    "      asDeviceAtomic<int32_t>(lock)->exchange(1, cuda::memory_order_consume);\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "inline void __device__ atomicUnlock(T* lock) {\n"
    "  asDeviceAtomic<int32_t>(lock)->store(0, cuda::memory_order_release);\n"
    "}\n"
    "namespace detail {\n"
    "template <typename T>\n"
    "inline __device__ T* allocateFixed(AllocationRange& range, int32_t size) {\n"
    "  if (range.fixedFull) {\n"
    "    return nullptr;\n"
    "  }\n"
    "  auto offset = atomicAdd(&range.rowOffset, size);\n"
    "  if (offset + size <= range.rowLimit) {\n"
    "    return reinterpret_cast<T*>(range.base + offset);\n"
    "  }\n"
    "  range.fixedFull = true;\n"
    "  return nullptr;\n"
    "}\n"
    "\n"
    // NOTE(review): in detail::allocate<T> below, atomicAdd returns the value
    // of 'stringOffset' *before* the decrement, so the reserved bytes are
    // [offset - size, offset), but the function returns 'base + offset', one
    // past the reservation. For the first allocation that is base + capacity.
    // It presumably should return 'base + offset - size'. Also, 'size' is
    // size_t, so 'offset - size' is computed in size_t and wraps when
    // offset < size, letting the '>= rowLimit' check pass.
    "template <typename T>\n"
    "inline __device__ T* allocate(AllocationRange& range, int32_t count) {\n"
    "  if (range.variableFull) {\n"
    "    return nullptr;\n"
    "  }\n"
    "  auto size = sizeof(T) * count;\n"
    "  auto offset = atomicAdd(&range.stringOffset, -size);\n"
    "  if (offset - size >= range.rowLimit) {\n"
    "    return reinterpret_cast<T*>(range.base + offset);\n"
    "  }\n"
    "  range.variableFull = true;\n"
    "  return nullptr;\n"
    "}\n"
    "} // namespace detail\n"
    "\n"
    "/// Allocator subclass that defines device member functions.\n"
    "struct RowAllocator : public HashPartitionAllocator {\n"
    "  template <typename T>\n"
    "  T* __device__ allocateRow() {\n"
    "    if (!ranges[0].fixedFull) {\n"
    "      auto ptr = detail::allocateFixed<T>(ranges[0], rowSize);\n"
    "      if (ptr) {\n"
    "        return ptr;\n"
    "      }\n"
    "      if (ranges[1].fixedFull) {\n"
    "        return nullptr;\n"
    "      }\n"
    "    }\n"
    "    return detail::allocateFixed<T>(ranges[1], rowSize);\n"
    "  }\n"
    "\n"
    "  template <typename T>\n"
    "  T* __device__ allocate(int32_t count) {\n"
    "    if (!ranges[0].variableFull) {\n"
    "      auto ptr = detail::allocate<T>(ranges[0], count);\n"
    "      if (ptr) {\n"
    "        return ptr;\n"
    "      }\n"
    "      if (ranges[1].variableFull) {\n"
    "        return nullptr;\n"
    "      }\n"
    "    }\n"
    "    return detail::allocate<T>(ranges[1], count);\n"
    "  }\n"
    "\n"
    // NOTE(review): in markRowFree, '1 << (idx & 31)' is a signed int shift
    // that reaches the sign bit for bit 31. '1u << (idx & 31)' would match the
    // uint32_t word passed to atomicOr.
    "  template <typename T>\n"
    "  bool __device__ markRowFree(T* row) {\n"
    "    auto ptr = reinterpret_cast<uintptr_t>(row);\n"
    "    AllocationRange* rowRange;\n"
    "    if (ptr >= ranges[0].base + ranges[0].firstRowOffset &&\n"
    "        ptr < ranges[0].base + ranges[0].rowLimit) {\n"
    "      rowRange = &ranges[0];\n"
    "    } else if (\n"
    "        ptr >= ranges[1].base + ranges[1].firstRowOffset &&\n"
    "        ptr < ranges[1].base + ranges[1].rowLimit) {\n"
    "      rowRange = &ranges[1];\n"
    "    } else {\n"
    "      return false;\n"
    "    }\n"
    "    int32_t idx = (ptr - (rowRange->base + rowRange->firstRowOffset)) / rowSize;\n"
    "    atomicOr(\n"
    "        reinterpret_cast<uint32_t*>(rowRange->base) + (idx >> 5),\n"
    "        1 << (idx & 31));\n"
    "    return true;\n"
    "  }\n"
    "};\n"
    "\n"
    "inline uint8_t __device__ hashTag(uint64_t h) {\n"
    "  return 0x80 | (h >> 32);\n"
    "}\n"
    "\n"
    "struct GpuBucket : public GpuBucketMembers {\n"
    "  template <typename RowType>\n"
    "  inline RowType* __device__ load(int32_t idx) const {\n"
    "    uint64_t uptr = reinterpret_cast<const uint32_t*>(&data)[idx];\n"
    "    if (uptr == 0) {\n"
    "      return nullptr;\n"
    "    }\n"
    "    uptr |= static_cast<uint64_t>(data[idx + 8]) << 32;\n"
    "    return reinterpret_cast<RowType*>(uptr);\n"
    "  }\n"
    "\n"
    "  template <typename RowType>\n"
    "  inline RowType* __device__ loadConsume(int32_t idx) {\n"
    "    uint64_t uptr =\n"
    "        asDeviceAtomic<uint32_t>(&data)[idx].load(cuda::memory_order_consume);\n"
    "    if (uptr == 0) {\n"
    "      return nullptr;\n"
    "    }\n"
    "    uptr |= static_cast<uint64_t>(data[idx + 8]) << 32;\n"
    "    return reinterpret_cast<RowType*>(uptr);\n"
    "  }\n"
    "\n"
    "  template <typename RowType>\n"
    "  inline RowType* __device__ loadWithWait(int32_t idx) {\n"
    "    RowType* hit;\n"
    "    do {\n"
    "      // It could be somebody inserted the tag but did not fill in the\n"
    "      // pointer. The pointer is coming in a few clocks.\n"
    "      hit = loadConsume<RowType>(idx);\n"
    "    } while (!hit);\n"
    "    return hit;\n"
    "  }\n"
    "\n"
    "  inline void __device__ store(int32_t idx, void* ptr) {\n"
    "    auto uptr = reinterpret_cast<uint64_t>(ptr);\n"
    "    data[8 + idx] = uptr >> 32;\n"
    "    // The high part must be seen if the low part is seen.\n"
    "    asDeviceAtomic<uint32_t>(&data)[idx].store(\n"
    "        uptr, cuda::memory_order_release);\n"
    "  }\n"
    "\n"
    "  bool __device__ addNewTag(uint8_t tag, uint32_t oldTags, uint8_t tagShift) {\n"
    "    uint32_t newTags = oldTags | ((static_cast<uint32_t>(tag) << tagShift));\n"
    "    return (oldTags == atomicCAS(&tags, oldTags, newTags));\n"
    "  }\n"
    "};\n"
    "\n"
    "class GpuHashTable : public GpuHashTableBase {\n"
    " public:\n"
    "  static constexpr int32_t kExclusive = 1;\n"
    "\n"
    "  template <typename RowType, typename Compare>\n"
    "  RowType* __device__ joinProbe(uint64_t h, Compare compare) {\n"
    "    uint32_t tagWord = hashTag(h);\n"
    "    tagWord |= tagWord << 8;\n"
    "    tagWord = tagWord | tagWord << 16;\n"
    "    auto bucketIdx = h & sizeMask;\n"
    "    for (;;) {\n"
    "      GpuBucket* bucket = buckets + bucketIdx;\n"
    "      auto tags = bucket->tags;\n"
    "      auto hits = __vcmpeq4(tags, tagWord) & 0x01010101;\n"
    "      while (hits) {\n"
    "        auto hitIdx = (__ffs(hits) - 1) / 8;\n"
    "        auto* hit = bucket->load<RowType>(hitIdx);\n"
    "        if (compare(hit)) {\n"
    "          return hit;\n"
    "        }\n"
    "        hits = hits & (hits - 1);\n"
    "      }\n"
    "      if (__vcmpeq4(tags, 0)) {\n"
    "        return nullptr;\n"
    "      }\n"
    "      bucketIdx = (bucketIdx + 1) & sizeMask;\n"
    "    }\n"
    "  }\n"
    "\n"
    "  template <typename RowType, typename Init>\n"
    "  bool __device__ addJoinRow(Init init) {\n"
    "    auto* row = allocators[0].allocateRow<RowType>();\n"
    "    if (!row) {\n"
    "      return false;\n"
    "    }\n"
    "    init(row);\n"
    "    return true;\n"
    "  }\n"
    "\n"
    // NOTE(review): in both updatingProbe overloads below, 'hitIdx' is only
    // assigned inside the tag-match loop. When a lane's 'hit' comes from the
    // insert path ('hit = toInsert'), 'hitIdx' may be uninitialized when it is
    // passed to ops.getExclusive. Presumably Ops ignores it for fresh inserts
    // -- TODO confirm. 'leader' is the highest lane set in 'peers'.
    "  template <typename RowType, typename Ops>\n"
    "  void __device__\n"
    "  updatingProbe(int32_t i, int32_t lane, bool isLaneActive, Ops& ops) {\n"
    "    uint32_t laneMask = __ballot_sync(0xffffffff, isLaneActive);\n"
    "    if (!isLaneActive) {\n"
    "      return;\n"
    "    }\n"
    "    auto h = ops.hash(i);\n"
    "    uint32_t tagWord = hashTag(h);\n"
    "    tagWord |= tagWord << 8;\n"
    "    tagWord = tagWord | tagWord << 16;\n"
    "    auto bucketIdx = h & sizeMask;\n"
    "    uint32_t misses = 0;\n"
    "    RowType* hit = nullptr;\n"
    "    RowType* toInsert = nullptr;\n"
    "    int32_t hitIdx;\n"
    "    GpuBucket* bucket;\n"
    "    uint32_t tags;\n"
    "    for (;;) {\n"
    "      bucket = buckets + bucketIdx;\n"
    "    reprobe:\n"
    "      tags = asDeviceAtomic<uint32_t>(&bucket->tags)\n"
    "                 ->load(cuda::memory_order_consume);\n"
    "      auto hits = __vcmpeq4(tags, tagWord) & 0x01010101;\n"
    "      while (hits) {\n"
    "        hitIdx = (__ffs(hits) - 1) / 8;\n"
    "        auto candidate = bucket->loadWithWait<RowType>(hitIdx);\n"
    "        if (ops.compare(this, candidate, i)) {\n"
    "          if (toInsert) {\n"
    "            ops.freeInsertable(this, toInsert, h);\n"
    "          }\n"
    "          hit = candidate;\n"
    "          break;\n"
    "        }\n"
    "        hits = hits & (hits - 1);\n"
    "      }\n"
    "      if (hit) {\n"
    "        break;\n"
    "      }\n"
    "      misses = __vcmpeq4(tags, 0);\n"
    "      if (misses) {\n"
    "        auto success = ops.insert(\n"
    "            this, partitionIdx(h), bucket, misses, tags, tagWord, i, toInsert);\n"
    "        if (success == ProbeState::kRetry) {\n"
    "          goto reprobe;\n"
    "        }\n"
    "        if (success == ProbeState::kNeedSpace) {\n"
    "          ops.addHostRetry(i);\n"
    "          hit = nullptr;\n"
    "          break;\n"
    "        }\n"
    "        hit = toInsert;\n"
    "        break;\n"
    "      }\n"
    "      bucketIdx = (bucketIdx + 1) & sizeMask;\n"
    "    }\n"
    "    // Every lane has a hit, or a nullptr if out of space.\n"
    "    uint32_t peers = __match_any_sync(laneMask, reinterpret_cast<int64_t>(hit));\n"
    "    if (hit) {\n"
    "      int32_t leader = (kWarpThreads - 1) - __clz(peers);\n"
    "      RowType* writable = nullptr;\n"
    "      if (lane == leader) {\n"
    "        writable = ops.getExclusive(this, bucket, hit, hitIdx);\n"
    "      }\n"
    "      auto toUpdate = peers;\n"
    "      while (toUpdate) {\n"
    "        auto peer = __ffs(toUpdate) - 1;\n"
    "        auto idxToUpdate = __shfl_sync(peers, i, peer);\n"
    "        if (lane == leader) {\n"
    "          ops.update(this, bucket, writable, idxToUpdate);\n"
    "        }\n"
    "        toUpdate &= toUpdate - 1;\n"
    "      }\n"
    "      if (lane == leader) {\n"
    "        ops.writeDone(writable);\n"
    "      }\n"
    "    }\n"
    "  }\n"
    "\n"
    "  template <\n"
    "      typename RowType,\n"
    "      typename Ops,\n"
    "      typename Compare,\n"
    "      typename Init,\n"
    "      typename Update>\n"
    "  void __device__ updatingProbe(\n"
    "      int32_t i,\n"
    "      int32_t lane,\n"
    "      bool isLaneActive,\n"
    "      Ops& ops,\n"
    "      Compare compare,\n"
    "      Init init,\n"
    "      Update update) {\n"
    "    uint32_t laneMask = __ballot_sync(0xffffffff, isLaneActive);\n"
    "    if (!isLaneActive) {\n"
    "      return;\n"
    "    }\n"
    "    auto h = ops.hash(i);\n"
    "    uint32_t tagWord = hashTag(h);\n"
    "    tagWord |= tagWord << 8;\n"
    "    tagWord = tagWord | tagWord << 16;\n"
    "    auto bucketIdx = h & sizeMask;\n"
    "    uint32_t misses = 0;\n"
    "    RowType* hit = nullptr;\n"
    "    RowType* toInsert = nullptr;\n"
    "    int32_t hitIdx;\n"
    "    GpuBucket* bucket;\n"
    "    uint32_t tags;\n"
    "    for (;;) {\n"
    "      bucket = buckets + bucketIdx;\n"
    "    reprobe:\n"
    "      tags = asDeviceAtomic<uint32_t>(&bucket->tags)\n"
    "                 ->load(cuda::memory_order_consume);\n"
    "      auto hits = __vcmpeq4(tags, tagWord) & 0x01010101;\n"
    "      while (hits) {\n"
    "        hitIdx = (__ffs(hits) - 1) / 8;\n"
    "        auto candidate = bucket->loadWithWait<RowType>(hitIdx);\n"
    "        if (compare(candidate)) {\n"
    "          if (toInsert) {\n"
    "            ops.freeInsertable(this, toInsert, h);\n"
    "          }\n"
    "          hit = candidate;\n"
    "          break;\n"
    "        }\n"
    "        hits = hits & (hits - 1);\n"
    "      }\n"
    "      if (hit) {\n"
    "        break;\n"
    "      }\n"
    "      misses = __vcmpeq4(tags, 0);\n"
    "      if (misses) {\n"
    "        auto success = ops.insert(\n"
    "            this,\n"
    "            partitionIdx(h),\n"
    "            bucket,\n"
    "            misses,\n"
    "            tags,\n"
    "            tagWord,\n"
    "            i,\n"
    "            toInsert,\n"
    "            init);\n"
    "        if (success == ProbeState::kRetry) {\n"
    "          goto reprobe;\n"
    "        }\n"
    "        if (success == ProbeState::kNeedSpace) {\n"
    "          ops.addHostRetry(i);\n"
    "          hit = nullptr;\n"
    "          break;\n"
    "        }\n"
    "        hit = toInsert;\n"
    "        break;\n"
    "      }\n"
    "      bucketIdx = (bucketIdx + 1) & sizeMask;\n"
    "    }\n"
    "    // Every lane has a hit, or a nullptr if out of space.\n"
    "    uint32_t peers = __match_any_sync(laneMask, reinterpret_cast<int64_t>(hit));\n"
    "    if (hit) {\n"
    "      int32_t leader = (kWarpThreads - 1) - __clz(peers);\n"
    "      RowType* writable = nullptr;\n"
    "      if (lane == leader) {\n"
    "        writable = ops.getExclusive(this, bucket, hit, hitIdx);\n"
    "      }\n"
    "      update(this, hit, peers, leader, lane);\n"
    "      if (lane == leader) {\n"
    "        ops.writeDone(writable);\n"
    "      }\n"
    "    }\n"
    "  }\n"
    "\n"
    // NOTE(review): in rehash, the inner 'while (misses)' body always leaves
    // through 'goto reprobe' or 'goto next', so it behaves as an 'if'.
    "  template <typename RowType, typename Ops>\n"
    "  void __device__\n"
    "  rehash(GpuBucket* oldBuckets, int32_t numOldBuckets, Ops ops) {\n"
    "    auto stride = blockDim.x * gridDim.x;\n"
    "    for (auto idx = threadIdx.x + blockDim.x * blockIdx.x; idx < numOldBuckets;\n"
    "         idx += stride) {\n"
    "      for (auto slot = 0; slot < GpuBucketMembers::kNumSlots; ++slot) {\n"
    "        auto* row = oldBuckets[idx].load<RowType>(slot);\n"
    "        if (row) {\n"
    "          uint64_t h = ops.hashRow(row);\n"
    "          auto bucketIdx = h & sizeMask;\n"
    "          uint32_t tagWord = hashTag(h);\n"
    "          tagWord |= tagWord << 8;\n"
    "          tagWord = tagWord | tagWord << 16;\n"
    "\n"
    "          for (;;) {\n"
    "            GpuBucket* bucket = buckets + bucketIdx;\n"
    "          reprobe:\n"
    "            uint32_t tags = asDeviceAtomic<uint32_t>(&bucket->tags)\n"
    "                                ->load(cuda::memory_order_consume);\n"
    "            auto misses = __vcmpeq4(tags, 0) & 0x01010101;\n"
    "            while (misses) {\n"
    "              auto missShift = __ffs(misses) - 1;\n"
    "              if (!bucket->addNewTag(tagWord, tags, missShift)) {\n"
    "                goto reprobe;\n"
    "              }\n"
    "              bucket->store(missShift / 8, row);\n"
    "              goto next;\n"
    "            }\n"
    "            bucketIdx = (bucketIdx + 1) & sizeMask;\n"
    "          }\n"
    "        }\n"
    "      next:;\n"
    "      }\n"
    "    }\n"
    "    __syncthreads();\n"
    "  }\n"
    "\n"
    // NOTE(review): in joinBuild, '*row->nextPtr() = previous' runs only after
    // the CAS has published 'row' in candidate->next. Until then, a concurrent
    // reader walking the chain could reach 'row' before its next link is set.
    // Writing the link before the CAS (and rewriting it on retry) would close
    // that window.
    "  template <typename RowType, typename Ops>\n"
    "  void __device__ joinBuild(RowType* rows, int32_t numRows, Ops ops) {\n"
    "    auto stride = blockDim.x * gridDim.x;\n"
    "    for (auto idx = threadIdx.x + blockDim.x * blockIdx.x; idx < numRows;\n"
    "         idx += stride) {\n"
    "      auto* row = rows + idx;\n"
    "      uint64_t h = ops.hashRow(row);\n"
    "      auto bucketIdx = h & sizeMask;\n"
    "      uint32_t tagWord = hashTag(h);\n"
    "      tagWord |= tagWord << 8;\n"
    "      tagWord = tagWord | tagWord << 16;\n"
    "\n"
    "      for (;;) {\n"
    "        GpuBucket* bucket = buckets + bucketIdx;\n"
    "      reprobe:\n"
    "        uint32_t tags = asDeviceAtomic<uint32_t>(&bucket->tags)\n"
    "                            ->load(cuda::memory_order_consume);\n"
    "        auto hits = __vcmpeq4(tags, tagWord) & 0x01010101;\n"
    "        while (hits) {\n"
    "          auto hitIdx = (__ffs(hits) - 1) / 8;\n"
    "          auto candidate = bucket->loadWithWait<RowType>(hitIdx);\n"
    "          if (ops.compare(row, candidate)) {\n"
    "            for (;;) {\n"
    "              auto previous = asDeviceAtomic<RowType*>(candidate->nextPtr())\n"
    "                                  ->load(cuda::memory_order_relaxed);\n"
    "              if ((unsigned long long)previous ==\n"
    "                  atomicCAS(\n"
    "                      (unsigned long long*)&candidate->next,\n"
    "                      (unsigned long long)previous,\n"
    "                      (unsigned long long)row)) {\n"
    "                *row->nextPtr() = previous;\n"
    "                // Set duplicates flag, no need to set if already set.\n"
    "                atomicCAS(&hasDuplicates, 0, 1);\n"
    "                goto next;\n"
    "              }\n"
    "            }\n"
    "          }\n"
    "          hits &= hits - 1;\n"
    "        }\n"
    "        auto misses = __vcmpeq4(tags, 0) & 0x01010101;\n"
    "        if (misses) {\n"
    "          auto missShift = __ffs(misses) - 1;\n"
    "          if (!bucket->addNewTag(tagWord, tags, missShift)) {\n"
    "            goto reprobe;\n"
    "          }\n"
    "          bucket->store(missShift / 8, row);\n"
    "          goto next;\n"
    "        }\n"
    "\n"
    "        bucketIdx = (bucketIdx + 1) & sizeMask;\n"
    "      }\n"
    "    next:;\n"
    "    }\n"
    "    __syncthreads();\n"
    "  }\n"
    "\n"
    "  int32_t __device__ partitionIdx(uint64_t h) const {\n"
    "    return partitionMask == 0 ? 0 : (h >> 41) & partitionMask;\n"
    "  }\n"
    "};\n"
    "} // namespace facebook::velox::wave\n";
// Passes the embedded HashTable.cuh text to registerHeader() while this
// namespace-scope variable is dynamically initialized; the result is kept here.
bool velox_experimental_wave_common_HashTable_cuh_reg{
    registerHeader(velox_experimental_wave_common_HashTable_cuh)};
// Embedded text for velox/experimental/wave/common/hash.cuh. Unlike the other
// entries, it holds only the path line and no header contents. Presumably this
// registers an empty header -- TODO confirm that is intended.
const char* velox_experimental_wave_common_hash_cuh =
    "velox/experimental/wave/common/"
    "hash.cuh\n";
// Passes the embedded hash.cuh text to registerHeader() while this
// namespace-scope variable is dynamically initialized; the result is kept here.
bool velox_experimental_wave_common_hash_cuh_reg{
    registerHeader(velox_experimental_wave_common_hash_cuh)};
// Text of velox/experimental/wave/common/StringView.cuh embedded as a C string:
// the first line is the header path, the rest is the header source (the device
// definition of StringView::cas). It is handed to registerHeader() by
// velox_experimental_wave_common_StringView_cuh_reg. The bytes are runtime data
// and appear to be generated from the source header -- TODO confirm.
const char* velox_experimental_wave_common_StringView_cuh =
    "velox/experimental/wave/common/StringView.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include \"velox/experimental/wave/common/StringView.h\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    // NOTE(review): StringView::cas below does an atomicCAS on 'data_' and
    // wraps the previous word in the returned StringView. It is defined out of
    // class in a header without 'inline'. If the header is included by several
    // device translation units that are linked together (separable
    // compilation), this may produce duplicate definitions -- TODO confirm
    // against the declaration in StringView.h.
    "__device__ StringView StringView::cas(StringView compare, StringView val) {\n"
    "  StringView old;\n"
    "  old.data_ = atomicCAS(&data_, compare.data_, val.data_);\n"
    "  return old;\n"
    "}\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Passes the embedded StringView.cuh text to registerHeader() while this
// namespace-scope variable is dynamically initialized; the result is kept here.
bool velox_experimental_wave_common_StringView_cuh_reg{
    registerHeader(velox_experimental_wave_common_StringView_cuh)};
// Embedded copy of velox/experimental/wave/common/StringView.h. The first
// line of the text is the include path.
//
// StringView packs a string reference into one 64-bit word:
//   - The low kSizeBits (16) bits hold the length.
//   - Strings of up to kInlineSize (6) bytes are copied into bytes 2..7 of
//     the word.
//   - Longer strings store the data pointer shifted left by kSizeBits. The
//     pointer must therefore fit in 48 bits, which is asserted in init().
// Writing the inline bytes at offset kSizeBits / 8 leaves the length bits
// untouched only on a little-endian target.
//
// NOTE(review): inline equality compares whole words. That is correct only
// after init(), which assigns data_ = len before the memcpy and so zeroes the
// unused inline bytes. StringView is trivial (see the static_assert), so a
// default-constructed value has an indeterminate data_.
// NOTE(review): init() calls memcpy, and the host-only paths use
// std::string_view / std::is_trivial_v. The header includes only <assert.h>
// and <stdint.h>, so host builds presumably rely on transitive includes.
// Confirm this, and fix it in the source header rather than here.
const char* velox_experimental_wave_common_StringView_h =
    "velox/experimental/wave/common/StringView.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <assert.h>\n"
    "#include <stdint.h>\n"
    "#include \"velox/experimental/wave/common/CompilerDefines.h\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "WAVE_DEVICE_HOST inline int\n"
    "stringview_memcmp(const void* lhs, const void* rhs, size_t n) {\n"
    "  auto* a = reinterpret_cast<const unsigned char*>(lhs);\n"
    "  auto* b = reinterpret_cast<const unsigned char*>(rhs);\n"
    "  for (size_t i = 0; i < n; ++i) {\n"
    "    if (int c = (int)a[i] - (int)b[i]) {\n"
    "      return c;\n"
    "    }\n"
    "  }\n"
    "  return 0;\n"
    "}\n"
    "\n"
    "class StringView {\n"
    " public:\n"
    "  WAVE_DEVICE_HOST void init(const char* data, int32_t len) {\n"
    "    data_ = len;\n"
    "    if (len == 0) {\n"
    "      return;\n"
    "    }\n"
    "    assert(len > 0);\n"
    "    if (len <= kInlineSize) {\n"
    "      memcpy(inlineData(), data, len);\n"
    "    } else {\n"
    "      assert(len <= kMaxSize);\n"
    "      assert(!((uintptr_t)data >> (64 - kSizeBits)));\n"
    "      data_ |= reinterpret_cast<uintptr_t>(data) << kSizeBits;\n"
    "    }\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST uint16_t size() const {\n"
    "    return data_ & kMaxSize;\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST bool isInline() const {\n"
    "    return size() <= kInlineSize;\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST const char* data() const {\n"
    "    if (isInline()) {\n"
    "      return const_cast<StringView*>(this)->inlineData();\n"
    "    }\n"
    "    return reinterpret_cast<const char*>(data_ >> kSizeBits);\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST bool operator==(StringView other) const {\n"
    "    if (isInline()) {\n"
    "      return data_ == other.data_;\n"
    "    }\n"
    "    auto len = size();\n"
    "    return len == other.size() &&\n"
    "        stringview_memcmp(data(), other.data(), len) == 0;\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST bool operator!=(StringView other) const {\n"
    "    return !(*this == other);\n"
    "  }\n"
    "\n"
    "#ifdef __NVCC__\n"
    "  __device__ StringView cas(StringView compare, StringView val);\n"
    "#endif\n"
    "#if !defined(__NVCC__) && !defined(__CUDACC_RTC__)\n"
    "  operator std::string_view() const {\n"
    "    return {data(), size()};\n"
    "  }\n"
    "#endif\n"
    "\n"
    " private:\n"
    "  WAVE_DEVICE_HOST char* inlineData() {\n"
    "    return reinterpret_cast<char*>(&data_) + kSizeBits / 8;\n"
    "  }\n"
    "\n"
    "  static constexpr int kSizeBits = 16;\n"
    "  static constexpr uint64_t kMaxSize = (1ull << kSizeBits) - 1;\n"
    "  static constexpr int kInlineSize = 8 - kSizeBits / 8;\n"
    "\n"
    "  unsigned long long data_;\n"
    "};\n"
    "\n"
    "// Non-trivial class does not play well in device code.\n"
    "#ifndef __CUDACC_RTC__\n"
    "static_assert(std::is_trivial_v<StringView>);\n"
    "#endif\n"
    "// Ensure StringView is 64 bits so we can do atomic operations on it.\n"
    "static_assert(sizeof(StringView) == 8);\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Calls registerHeader() during static initialization. The stored bool is not
// read anywhere in this chunk.
bool velox_experimental_wave_common_StringView_h_reg =
    registerHeader(velox_experimental_wave_common_StringView_h);
// Embedded copy of velox/experimental/wave/common/Hash.h. The first line of
// the text is the include path. The header provides:
//   - Murmur3-style byte hashing.
//   - 32-bit integer mixers (jenkinsRevMix32, twang32From64).
//   - hashMix for combining two 64-bit values.
//   - Hasher<Input, Output> specializations for StringView, int32_t and
//     int64_t.
//
// NOTE(review): hashValue<T>() instantiates Hasher<T, uint64_t>. The only such
// specialization in this header is for int64_t, so other T fail to compile
// unless a specialization is provided elsewhere.
// NOTE(review): Hasher<int64_t, uint64_t> multiplies and left-shifts a signed
// int64_t. `x * 0xcc9e2d51` can overflow, and `k << 12` on a negative value is
// undefined before C++20. Doing the arithmetic on uint64_t in the source
// header would make the result well-defined.
// NOTE(review): Murmur3::hashBytes mixes each tail byte separately through
// mixH1 and passes plain (possibly signed) char to mixK1, so it likely does
// not match reference MurmurHash3_x86_32 output. Confirm that cross-
// implementation compatibility is not required.
const char* velox_experimental_wave_common_Hash_h =
    "velox/experimental/wave/common/Hash.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include \"velox/experimental/wave/common/CompilerDefines.h\"\n"
    "#include \"velox/experimental/wave/common/StringView.h\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "template <typename Input, typename Output>\n"
    "struct Hasher;\n"
    "\n"
    "class Murmur3 {\n"
    " public:\n"
    "  WAVE_DEVICE_HOST static uint32_t\n"
    "  hashBytes(const char* data, size_t len, uint32_t seed) {\n"
    "    auto h1 = seed;\n"
    "    size_t i = 0;\n"
    "    for (; i + 4 <= len; i += 4) {\n"
    "      uint32_t k1;\n"
    "      memcpy(&k1, data + i, sizeof(uint32_t));\n"
    "      h1 = mixH1(h1, mixK1(k1));\n"
    "    }\n"
    "    for (; i < len; ++i) {\n"
    "      h1 = mixH1(h1, mixK1(data[i]));\n"
    "    }\n"
    "    return fmix(h1, len);\n"
    "  }\n"
    "\n"
    " private:\n"
    "  WAVE_DEVICE_HOST static uint32_t rotl32(uint32_t a, int shift) {\n"
    "    return (a << shift) | (a >> (32 - shift));\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST static uint32_t mixK1(uint32_t k1) {\n"
    "    k1 *= 0xcc9e2d51;\n"
    "    k1 = rotl32(k1, 15);\n"
    "    k1 *= 0x1b873593;\n"
    "    return k1;\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST static uint32_t mixH1(uint32_t h1, uint32_t k1) {\n"
    "    h1 ^= k1;\n"
    "    h1 = rotl32(h1, 13);\n"
    "    h1 = h1 * 5 + 0xe6546b64;\n"
    "    return h1;\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST static uint32_t fmix(uint32_t h1, uint32_t length) {\n"
    "    h1 ^= length;\n"
    "    h1 ^= h1 >> 16;\n"
    "    h1 *= 0x85ebca6b;\n"
    "    h1 ^= h1 >> 13;\n"
    "    h1 *= 0xc2b2ae35;\n"
    "    h1 ^= h1 >> 16;\n"
    "    return h1;\n"
    "  }\n"
    "};\n"
    "\n"
    "WAVE_DEVICE_HOST inline uint32_t jenkinsRevMix32(uint32_t key) {\n"
    "  key += (key << 12); // key *= (1 + (1 << 12))\n"
    "  key ^= (key >> 22);\n"
    "  key += (key << 4); // key *= (1 + (1 << 4))\n"
    "  key ^= (key >> 9);\n"
    "  key += (key << 10); // key *= (1 + (1 << 10))\n"
    "  key ^= (key >> 2);\n"
    "  // key *= (1 + (1 << 7)) * (1 + (1 << 12))\n"
    "  key += (key << 7);\n"
    "  key += (key << 12);\n"
    "  return key;\n"
    "}\n"
    "\n"
    "WAVE_DEVICE_HOST inline uint32_t twang32From64(uint64_t key) {\n"
    "  key = (~key) + (key << 18);\n"
    "  key = key ^ (key >> 31);\n"
    "  key = key * 21;\n"
    "  key = key ^ (key >> 11);\n"
    "  key = key + (key << 6);\n"
    "  key = key ^ (key >> 22);\n"
    "  return static_cast<uint32_t>(key);\n"
    "}\n"
    "\n"
    "WAVE_DEVICE_HOST inline uint64_t hashMix(\n"
    "    const uint64_t upper,\n"
    "    const uint64_t lower) {\n"
    "  // Murmur-inspired hashing.\n"
    "  const uint64_t kMul = 0x9ddfea08eb382d69ULL;\n"
    "  uint64_t a = (lower ^ upper) * kMul;\n"
    "  a ^= (a >> 47);\n"
    "  uint64_t b = (upper ^ a) * kMul;\n"
    "  b ^= (b >> 47);\n"
    "  b *= kMul;\n"
    "  return b;\n"
    "}\n"
    "\n"
    "template <typename T>\n"
    "struct IntHasher32 {\n"
    "  WAVE_DEVICE_HOST uint32_t operator()(T val) const {\n"
    "    if constexpr (sizeof(T) <= 4) {\n"
    "      return jenkinsRevMix32(val);\n"
    "    } else {\n"
    "      return twang32From64(val);\n"
    "    }\n"
    "    __builtin_unreachable();\n"
    "  }\n"
    "};\n"
    "\n"
    "template <>\n"
    "struct Hasher<StringView, uint32_t> {\n"
    "  WAVE_DEVICE_HOST uint32_t operator()(StringView val) const {\n"
    "    return Murmur3::hashBytes(val.data(), val.size(), 42);\n"
    "  }\n"
    "};\n"
    "\n"
    "template <>\n"
    "struct Hasher<int32_t, uint32_t> : IntHasher32<int32_t> {};\n"
    "\n"
    "template <>\n"
    "struct Hasher<int64_t, uint32_t> : IntHasher32<int64_t> {};\n"
    "\n"
    "template <>\n"
    "struct Hasher<int64_t, uint64_t> {\n"
    "  WAVE_DEVICE_HOST uint64_t operator()(int64_t x) const {\n"
    "    auto k = x * 0xcc9e2d51;\n"
    "    return k ^ (x >> 13) ^ (k << 12);\n"
    "  }\n"
    "};\n"
    "\n"
    "template <typename T>\n"
    "inline WAVE_DEVICE_HOST uint64_t hashValue(T value) {\n"
    "  return Hasher<T, uint64_t>()(value);\n"
    "}\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Calls registerHeader() during static initialization. The stored bool is not
// read anywhere in this chunk.
bool velox_experimental_wave_common_Hash_h_reg =
    registerHeader(velox_experimental_wave_common_Hash_h);
// Embedded copy of velox/experimental/wave/common/CompilerDefines.h. The first
// line of the text is the include path. WAVE_DEVICE_HOST expands to:
//   - `__device__` when __CUDACC_RTC__ is defined (runtime compilation);
//   - `__device__ __host__` when __CUDACC__ is defined (offline CUDA
//     compilation);
//   - nothing for a plain host C++ compiler.
const char* velox_experimental_wave_common_CompilerDefines_h =
    "velox/experimental/wave/common/CompilerDefines.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "// Macro to declare execution domain if the header can be included in a C++ file\n"
    "// that is not compiled with nvcc\n"
    "#if defined(__CUDACC_RTC__)\n"
    "#define WAVE_DEVICE_HOST __device__\n"
    "#elif defined(__CUDACC__)\n"
    "#define WAVE_DEVICE_HOST __device__ __host__\n"
    "#else\n"
    "#define WAVE_DEVICE_HOST\n"
    "#endif\n";
// Calls registerHeader() during static initialization. The stored bool is not
// read anywhere in this chunk.
bool velox_experimental_wave_common_CompilerDefines_h_reg =
    registerHeader(velox_experimental_wave_common_CompilerDefines_h);
// Embedded copy of velox/experimental/wave/common/Atomic.cuh. The first line
// of the text is the include path.
//   - Atomic<T, Scope> forwards load/store/compare_exchange to breeze's
//     CudaPlatform<kWarpThreads, kWarpThreads> atomics, mapping the wave
//     MemoryOrder/MemoryScope enums onto breeze's.
//   - AtomicMutex is a spin lock built on Atomic<int>. It treats a nonzero
//     value as "unlocked": acquire() CASes it to 0 and release() stores 1.
//     The defaulted constructors do not initialize the value, and a
//     zero-initialized mutex reads as locked. Construct it with
//     AtomicMutex(1), or set its storage to 1, before the first acquire().
//
// NOTE(review): in acquire(), step_ns is unsigned int (int plus threadIdx.x
// % 32) and is doubled after every sleep with no cap. It therefore wraps and,
// after roughly 31 doublings, stays at 0, which turns the backoff into a busy
// spin. Clamp it in the source header (e.g. to about 1 ms, __nanosleep's
// documented maximum).
// NOTE(review): __nanosleep requires compute capability 7.0+, but the guard
// only checks that __CUDA_ARCH__ is defined (breeze/platforms/cuda.cuh allows
// __CUDA_ARCH__ >= 500).
const char* velox_experimental_wave_common_Atomic_cuh =
    "velox/experimental/wave/common/Atomic.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <breeze/platforms/platform.h>\n"
    "#include <breeze/utils/types.h>\n"
    "#include <breeze/platforms/cuda.cuh>\n"
    "#include \"velox/experimental/wave/common/BitUtil.cuh\"\n"
    "#include \"velox/experimental/wave/common/CompilerDefines.h\"\n"
    "\n"
    "namespace facebook::velox::wave {\n"
    "\n"
    "enum class MemoryScope { kDevice };\n"
    "\n"
    "enum class MemoryOrder {\n"
    "  kRelaxed,\n"
    "  kAcquire,\n"
    "  kRelease,\n"
    "};\n"
    "\n"
    "template <MemoryScope>\n"
    "struct MemoryScopeTraits;\n"
    "\n"
    "template <>\n"
    "struct MemoryScopeTraits<MemoryScope::kDevice> {\n"
    "  WAVE_DEVICE_HOST static constexpr breeze::utils::AddressSpace\n"
    "  breezeAddressSpace() {\n"
    "    return breeze::utils::GLOBAL;\n"
    "  }\n"
    "};\n"
    "\n"
    "template <MemoryOrder>\n"
    "struct MemoryOrderTraits;\n"
    "\n"
    "template <>\n"
    "struct MemoryOrderTraits<MemoryOrder::kRelaxed> {\n"
    "  WAVE_DEVICE_HOST static constexpr breeze::utils::MemoryOrder breezeType() {\n"
    "    return breeze::utils::RELAXED;\n"
    "  }\n"
    "};\n"
    "\n"
    "template <>\n"
    "struct MemoryOrderTraits<MemoryOrder::kAcquire> {\n"
    "  WAVE_DEVICE_HOST static constexpr breeze::utils::MemoryOrder breezeType() {\n"
    "    return breeze::utils::ACQUIRE;\n"
    "  }\n"
    "};\n"
    "\n"
    "template <>\n"
    "struct MemoryOrderTraits<MemoryOrder::kRelease> {\n"
    "  WAVE_DEVICE_HOST static constexpr breeze::utils::MemoryOrder breezeType() {\n"
    "    return breeze::utils::RELEASE;\n"
    "  }\n"
    "};\n"
    "\n"
    "template <typename T, MemoryScope Scope = MemoryScope::kDevice>\n"
    "struct Atomic {\n"
    " private:\n"
    "  // Note: We currently do not require usage of this class to specify\n"
    "  // kBlockThreads as that is not technically used by the CUDA platform\n"
    "  // implementation of the atomics in breeze and is non-trivial for\n"
    "  // existing usage of atomics to provide. We might want to change that\n"
    "  // in the future to remove that assumption.\n"
    "  using PlatformT = CudaPlatform</*kBlockThreads=*/kWarpThreads, kWarpThreads>;\n"
    "\n"
    "  T value_;\n"
    "\n"
    " public:\n"
    "  Atomic(const Atomic&) = delete;\n"
    "  Atomic& operator=(const Atomic&) = delete;\n"
    "  Atomic& operator=(const Atomic&) volatile = delete;\n"
    "\n"
    "  WAVE_DEVICE_HOST explicit constexpr Atomic() noexcept = default;\n"
    "  WAVE_DEVICE_HOST constexpr explicit inline Atomic(T value) noexcept\n"
    "      : value_(value) {}\n"
    "\n"
    "  template <MemoryOrder Order = MemoryOrder::kRelaxed>\n"
    "  WAVE_DEVICE_HOST T load() noexcept {\n"
    "    using namespace breeze::utils;\n"
    "    return PlatformT().atomic_load<MemoryOrderTraits<Order>::breezeType()>(\n"
    "        make_slice<MemoryScopeTraits<Scope>::breezeAddressSpace()>(&value_));\n"
    "  }\n"
    "\n"
    "  template <MemoryOrder Order = MemoryOrder::kRelaxed>\n"
    "  WAVE_DEVICE_HOST void store(T value) noexcept {\n"
    "    using namespace breeze::utils;\n"
    "    PlatformT().atomic_store<MemoryOrderTraits<Order>::breezeType()>(\n"
    "        make_slice<MemoryScopeTraits<Scope>::breezeAddressSpace()>(&value_),\n"
    "        value);\n"
    "  }\n"
    "\n"
    "  template <MemoryOrder Order = MemoryOrder::kRelaxed>\n"
    "  WAVE_DEVICE_HOST bool compare_exchange(T& expected, T desired) noexcept {\n"
    "    using namespace breeze::utils;\n"
    "    T actual = PlatformT().atomic_cas<MemoryOrderTraits<Order>::breezeType()>(\n"
    "        make_slice<MemoryScopeTraits<Scope>::breezeAddressSpace()>(&value_),\n"
    "        expected,\n"
    "        desired);\n"
    "    bool was_changed = actual == expected;\n"
    "    expected = actual;\n"
    "    return was_changed;\n"
    "  }\n"
    "};\n"
    "\n"
    "template <MemoryScope Scope = MemoryScope::kDevice>\n"
    "struct AtomicMutex {\n"
    " private:\n"
    "  Atomic<int, Scope> value_;\n"
    "\n"
    " public:\n"
    "  AtomicMutex(const AtomicMutex&) = delete;\n"
    "  AtomicMutex& operator=(const AtomicMutex&) = delete;\n"
    "  AtomicMutex& operator=(const AtomicMutex&) volatile = delete;\n"
    "\n"
    "  WAVE_DEVICE_HOST explicit constexpr AtomicMutex() noexcept = default;\n"
    "  WAVE_DEVICE_HOST constexpr explicit inline AtomicMutex(int value) noexcept\n"
    "      : value_(value) {}\n"
    "\n"
    "  WAVE_DEVICE_HOST void acquire() {\n"
    "    constexpr int kPollingCount = 5;\n"
    "    constexpr int kInitialBackoffStepNs = 10;\n"
    "\n"
    "    int count = 0;\n"
    "    auto step_ns = kInitialBackoffStepNs + threadIdx.x % 32;\n"
    "    for (;;) {\n"
    "      int available = value_.template load<MemoryOrder::kAcquire>();\n"
    "      while (available) {\n"
    "        if (value_.template compare_exchange<MemoryOrder::kAcquire>(\n"
    "                available, 0)) {\n"
    "          return;\n"
    "        }\n"
    "      }\n"
    "      if (count < kPollingCount) {\n"
    "        count += 1;\n"
    "        continue;\n"
    "      }\n"
    "#if defined(__CUDA_ARCH__)\n"
    "      __nanosleep(step_ns);\n"
    "      step_ns *= 2;\n"
    "#endif\n"
    "    }\n"
    "  }\n"
    "\n"
    "  WAVE_DEVICE_HOST void release() {\n"
    "    value_.template store<MemoryOrder::kRelease>(1);\n"
    "  }\n"
    "};\n"
    "\n"
    "} // namespace facebook::velox::wave\n";
// Calls registerHeader() during static initialization. The stored bool is not
// read anywhere in this chunk.
bool velox_experimental_wave_common_Atomic_cuh_reg =
    registerHeader(velox_experimental_wave_common_Atomic_cuh);
const char* breeze_platforms_cuda_cuh =
    "breeze/platforms/cuda.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "/*\n"
    " * Copyright (c) 2024 by Rivos Inc.\n"
    " * Licensed under the Apache License, Version 2.0, see LICENSE for details.\n"
    " * SPDX-License-Identifier: Apache-2.0\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include <cuda_runtime_api.h>\n"
    "\n"
    "#include \"breeze/utils/types.h\"\n"
    "\n"
    "#if defined(__CUDA_ARCH__)\n"
    "#if __CUDA_ARCH__ < 500\n"
    "#error \"Unsupported CUDA architecture\"\n"
    "#endif\n"
    "#endif\n"
    "\n"
    "struct CudaSpecialization {\n"
    "  template <int WARP_THREADS>\n"
    "  static __device__ __forceinline__ int lane_idx();\n"
    "  template <int WARP_THREADS>\n"
    "  static __device__ __forceinline__ int warp_idx();\n"
    "  template <int WARP_THREADS>\n"
    "  static __device__ __forceinline__ unsigned lower_rank_lanemask();\n"
    "  template <int WARP_THREADS>\n"
    "  static __device__ __forceinline__ unsigned higher_rank_lanemask();\n"
    "  static __device__ __forceinline__ void reconvergence_hint() {\n"
    "    // Use full warp syncs as hint to encourage reconvergence by default.\n"
    "    __syncwarp();\n"
    "  }\n"
    "  template <breeze::utils::MemoryOrder MEMORY_ORDER = breeze::utils::RELAXED,\n"
    "            typename SliceT, typename T = typename SliceT::data_type>\n"
    "  static __device__ __forceinline__ T atomic_load(SliceT address) {\n"
    "    static_assert(MEMORY_ORDER != breeze::utils::RELEASE,\n"
    "                  \"RELEASE is not a valid memory order for atomic_load\");\n"
    "    T value;\n"
    "    if constexpr (SliceT::ADDRESS_SPACE == breeze::utils::GLOBAL) {\n"
    "      value = __ldcg(address.data());  // cache only globally\n"
    "    } else {\n"
    "      value = *address.data();\n"
    "    }\n"
    "    if constexpr (MEMORY_ORDER == breeze::utils::ACQUIRE) {\n"
    "      __threadfence();\n"
    "    }\n"
    "    return value;\n"
    "  }\n"
    "  template <breeze::utils::MemoryOrder MEMORY_ORDER = breeze::utils::RELAXED,\n"
    "            typename SliceT, typename T = typename SliceT::data_type>\n"
    "  static __device__ __forceinline__ void atomic_store(SliceT address, T value) {\n"
    "    static_assert(MEMORY_ORDER != breeze::utils::ACQUIRE,\n"
    "                  \"ACQUIRE is not a valid memory order for atomic_store\");\n"
    "    if constexpr (MEMORY_ORDER == breeze::utils::RELEASE) {\n"
    "      __threadfence();\n"
    "    }\n"
    "    if constexpr (SliceT::ADDRESS_SPACE == breeze::utils::GLOBAL) {\n"
    "      __stcg(address.data(), value);  // cache only globally\n"
    "    } else {\n"
    "      *address.data() = value;\n"
    "    }\n"
    "  }\n"
    "  template <breeze::utils::MemoryOrder MEMORY_ORDER = breeze::utils::RELAXED,\n"
    "            typename SliceT, typename T = typename SliceT::data_type>\n"
    "  static __device__ __forceinline__ T atomic_cas(SliceT address, T compare,\n"
    "                                                 T value) {\n"
    "    if constexpr (MEMORY_ORDER == breeze::utils::RELEASE) {\n"
    "      __threadfence();\n"
    "    }\n"
    "    T old = atomicCAS(address.data(), compare, value);\n"
    "    if constexpr (MEMORY_ORDER == breeze::utils::ACQUIRE) {\n"
    "      __threadfence();\n"
    "    }\n"
    "    return old;\n"
    "  }\n"
    "  template <typename SliceT, typename T = typename SliceT::data_type>\n"
    "  static __device__ __forceinline__ T atomic_add(SliceT address, T value) {\n"
    "    return atomicAdd(address.data(), value);\n"
    "  }\n"
    "  template <typename SliceT, typename T = typename SliceT::data_type>\n"
    "  static __device__ __forceinline__ void atomic_min(SliceT address, T value) {\n"
    "    atomicMin(address.data(), value);\n"
    "  }\n"
    "  template <typename SliceT, typename T = typename SliceT::data_type>\n"
    "  static __device__ __forceinline__ void atomic_max(SliceT address, T value) {\n"
    "    atomicMax(address.data(), value);\n"
    "  }\n"
    "  template <int WARP_THREADS, typename T>\n"
    "  static __device__ __forceinline__ T reduce_add(T value) {\n"
    "#pragma unroll\n"
    "    for (int offset = WARP_THREADS / 2; offset > 0; offset /= 2) {\n"
    "      value += __shfl_xor_sync(0xffffffff, value, offset);\n"
    "    }\n"
    "    return value;\n"
    "  }\n"
    "  template <int WARP_THREADS, typename T>\n"
    "  static __device__ __forceinline__ T reduce_min(T value) {\n"
    "#pragma unroll\n"
    "    for (int offset = WARP_THREADS / 2; offset > 0; offset /= 2) {\n"
    "      value = ::min(value, __shfl_xor_sync(0xffffffff, value, offset));\n"
    "    }\n"
    "    return value;\n"
    "  }\n"
    "  template <int WARP_THREADS, typename T>\n"
    "  static __device__ __forceinline__ T reduce_max(T value) {\n"
    "#pragma unroll\n"
    "    for (int offset = WARP_THREADS / 2; offset > 0; offset /= 2) {\n"
    "      value = ::max(value, __shfl_xor_sync(0xffffffff, value, offset));\n"
    "    }\n"
    "    return value;\n"
    "  }\n"
    "  template <int WARP_THREADS, typename T>\n"
    "  static __device__ __forceinline__ T scan_add(T value) {\n"
    "#pragma unroll\n"
    "    for (int offset = 1; offset < WARP_THREADS; offset <<= 1) {\n"
    "      T result = __shfl_up_sync(0xffffffff, value, offset);\n"
    "      if ((threadIdx.x % WARP_THREADS) >= offset) {\n"
    "        value += result;\n"
    "      }\n"
    "    }\n"
    "    return value;\n"
    "  }\n"
    "  template <typename T>\n"
    "  static __device__ __forceinline__ int count_leading_zeros(T value);\n"
    "  template <typename T>\n"
    "  static __device__ __forceinline__ int population_count(T value);\n"
    "  template <typename T>\n"
    "  static __device__ __forceinline__ T extract_bits(T value, int start_bit,\n"
    "                                                   int num_bits);\n"
    "  template <bool HIGH_PRIORITY>\n"
    "  static __device__ __forceinline__ void scheduling_hint() {}\n"
    "  template <typename SliceT>\n"
    "  static __device__ __forceinline__ void prefetch(SliceT) {}\n"
    "};\n"
    "\n"
    "template <int CUDA_BLOCK_THREADS, int CUDA_WARP_THREADS>\n"
    "struct CudaPlatform {\n"
    "  enum {\n"
    "    BLOCK_THREADS = CUDA_BLOCK_THREADS,\n"
    "    WARP_THREADS = CUDA_WARP_THREADS,\n"
    "  };\n"
    "  __device__ __forceinline__ int thread_idx() { return threadIdx.x; }\n"
    "  __device__ __forceinline__ int block_idx() { return blockIdx.x; }\n"
    "  __device__ __forceinline__ void syncthreads() { __syncthreads(); }\n"
    "  __device__ __forceinline__ void syncwarp() { __syncwarp(); }\n"
    "  __device__ __forceinline__ int lane_idx() {\n"
    "    return CudaSpecialization::template lane_idx<WARP_THREADS>();\n"
    "  }\n"
    "  __device__ __forceinline__ int warp_idx() {\n"
    "    return CudaSpecialization::template warp_idx<WARP_THREADS>();\n"
    "  }\n"
    "  __device__ __forceinline__ unsigned lower_rank_lanemask() {\n"
    "    return CudaSpecialization::template lower_rank_lanemask<WARP_THREADS>();\n"
    "  }\n"
    "  __device__ __forceinline__ unsigned higher_rank_lanemask() {\n"
    "    return CudaSpecialization::template higher_rank_lanemask<WARP_THREADS>();\n"
    "  }\n"
    "  __device__ __forceinline__ void reconvergence_hint() {\n"
    "    CudaSpecialization::reconvergence_hint();\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T min(T lhs, T rhs) {\n"
    "    return ::min(lhs, rhs);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T max(T lhs, T rhs) {\n"
    "    return ::max(lhs, rhs);\n"
    "  }\n"
    "  template <breeze::utils::MemoryOrder MEMORY_ORDER = breeze::utils::RELAXED,\n"
    "            typename SliceT, typename T = typename SliceT::data_type>\n"
    "  __device__ __forceinline__ T atomic_load(SliceT address) {\n"
    "    return CudaSpecialization::atomic_load<MEMORY_ORDER>(address);\n"
    "  }\n"
    "  template <breeze::utils::MemoryOrder MEMORY_ORDER = breeze::utils::RELAXED,\n"
    "            typename SliceT, typename T = typename SliceT::data_type>\n"
    "  __device__ __forceinline__ void atomic_store(SliceT address, T value) {\n"
    "    CudaSpecialization::atomic_store<MEMORY_ORDER>(address, value);\n"
    "  }\n"
    "  template <breeze::utils::MemoryOrder MEMORY_ORDER = breeze::utils::RELAXED,\n"
    "            typename SliceT, typename T = typename SliceT::data_type>\n"
    "  __device__ __forceinline__ T atomic_cas(SliceT address, T compare, T value) {\n"
    "    return CudaSpecialization::atomic_cas<MEMORY_ORDER>(address, compare,\n"
    "                                                        value);\n"
    "  }\n"
    "  template <typename SliceT, typename T = typename SliceT::data_type>\n"
    "  __device__ __forceinline__ T atomic_add(SliceT address, T value) {\n"
    "    return CudaSpecialization::atomic_add(address, value);\n"
    "  }\n"
    "  template <typename SliceT, typename T = typename SliceT::data_type>\n"
    "  __device__ __forceinline__ void atomic_min(SliceT address, T value) {\n"
    "    CudaSpecialization::atomic_min(address, value);\n"
    "  }\n"
    "  template <typename SliceT, typename T = typename SliceT::data_type>\n"
    "  __device__ __forceinline__ void atomic_max(SliceT address, T value) {\n"
    "    CudaSpecialization::atomic_max(address, value);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T reduce_add(T value) {\n"
    "    return CudaSpecialization::template reduce_add<WARP_THREADS, T>(value);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T reduce_min(T value) {\n"
    "    return CudaSpecialization::template reduce_min<WARP_THREADS, T>(value);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T reduce_max(T value) {\n"
    "    return CudaSpecialization::template reduce_max<WARP_THREADS, T>(value);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T scan_add(T value) {\n"
    "    return CudaSpecialization::template scan_add<WARP_THREADS, T>(value);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T ballot(bool value) {\n"
    "    return __ballot_sync(0xffffffff, value);\n"
    "  }\n"
    "  template <int MIN_BITS, typename T>\n"
    "  __device__ __forceinline__ unsigned match_any(T value) {\n"
    "#if __CUDA_ARCH__ >= 800\n"
    "    return __match_any_sync(0xffffffff, value);\n"
    "#else\n"
    "    unsigned result;\n"
    "#pragma unroll\n"
    "    for (unsigned i = 0; i < MIN_BITS; ++i) {\n"
    "      unsigned current_bit = 1 << i;\n"
    "      bool pred = (value & current_bit) == current_bit;\n"
    "      unsigned mask = __ballot_sync(0xffffffff, pred);\n"
    "      if (!pred) {\n"
    "        mask = ~mask;\n"
    "      }\n"
    "      result = (i == 0) ? mask : result & mask;\n"
    "    }\n"
    "    return result;\n"
    "#endif\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ int count_leading_zeros(T value) {\n"
    "    return CudaSpecialization::count_leading_zeros(value);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ int population_count(T value) {\n"
    "    return CudaSpecialization::population_count(value);\n"
    "  }\n"
    "  template <typename T>\n"
    "  __device__ __forceinline__ T extract_bits(T value, int start_bit,\n"
    "                                            int num_bits) {\n"
    "    return CudaSpecialization::extract_bits(value, start_bit, num_bits);\n"
    "  }\n"
    "  template <bool HIGH_PRIORITY>\n"
    "  __device__ __forceinline__ void scheduling_hint() {\n"
    "    CudaSpecialization::scheduling_hint<HIGH_PRIORITY>();\n"
    "  }\n"
    "  template <typename SliceT>\n"
    "  __device__ __forceinline__ void prefetch(SliceT address) {\n"
    "    CudaSpecialization::prefetch(address);\n"
    "  }\n"
    "};\n"
    "\n"
    "#if CUDART_VERSION >= 12080 && __CUDA_ARCH__ >= 700\n"
    "// specialization for MEMORY_ORDER=ACQUIRE, SliceT=Slice<GLOBAL, BLOCKED, int>\n"
    "template <>\n"
    "__device__ __forceinline__ int CudaSpecialization::atomic_load<\n"
    "    breeze::utils::ACQUIRE,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, int>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, int>\n"
    "        address) {\n"
    "  return __nv_atomic_load_n(address.data(), __NV_ATOMIC_ACQUIRE,\n"
    "                            __NV_THREAD_SCOPE_DEVICE);\n"
    "}\n"
    "\n"
    "// specialization for MEMORY_ORDER=ACQUIRE, SliceT=Slice<GLOBAL, BLOCKED,\n"
    "// unsigned>\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned CudaSpecialization::atomic_load<\n"
    "    breeze::utils::ACQUIRE,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         unsigned>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         unsigned>\n"
    "        address) {\n"
    "  return __nv_atomic_load_n(address.data(), __NV_ATOMIC_ACQUIRE,\n"
    "                            __NV_THREAD_SCOPE_DEVICE);\n"
    "}\n"
    "\n"
    "// specialization for MEMORY_ORDER=ACQUIRE, SliceT=Slice<GLOBAL, BLOCKED, int>\n"
    "template <>\n"
    "__device__ __forceinline__ void CudaSpecialization::atomic_store<\n"
    "    breeze::utils::ACQUIRE,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, int>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, int>\n"
    "        address,\n"
    "    int value) {\n"
    "  __nv_atomic_store_n(address.data(), value, __NV_ATOMIC_RELEASE,\n"
    "                      __NV_THREAD_SCOPE_DEVICE);\n"
    "}\n"
    "\n"
    "// specialization for MEMORY_ORDER=ACQUIRE, SliceT=Slice<GLOBAL, BLOCKED,\n"
    "// unsigned>\n"
    "template <>\n"
    "__device__ __forceinline__ void CudaSpecialization::atomic_store<\n"
    "    breeze::utils::ACQUIRE,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         unsigned>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         unsigned>\n"
    "        address,\n"
    "    unsigned value) {\n"
    "  __nv_atomic_store_n(address.data(), value, __NV_ATOMIC_RELEASE,\n"
    "                      __NV_THREAD_SCOPE_DEVICE);\n"
    "}\n"
    "\n"
    "// specialization for MEMORY_ORDER=ACQUIRE, SliceT=Slice<GLOBAL, BLOCKED,\n"
    "// unsigned>\n"
    "template <>\n"
    "__device__ __forceinline__ int CudaSpecialization::atomic_cas<\n"
    "    breeze::utils::ACQUIRE,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, int>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, int>\n"
    "        address,\n"
    "    int compare, int value) {\n"
    "  int expected = compare;\n"
    "  __nv_atomic_compare_exchange_n(address.data(), &expected, value, false,\n"
    "                                 __NV_ATOMIC_ACQUIRE, __NV_ATOMIC_ACQUIRE,\n"
    "                                 __NV_THREAD_SCOPE_DEVICE);\n"
    "  return expected;\n"
    "}\n"
    "\n"
    "// specialization for MEMORY_ORDER=ACQUIRE, SliceT=Slice<GLOBAL, BLOCKED,\n"
    "// unsigned>\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned CudaSpecialization::atomic_cas<\n"
    "    breeze::utils::ACQUIRE,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         unsigned>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         unsigned>\n"
    "        address,\n"
    "    unsigned compare, unsigned value) {\n"
    "  unsigned expected = compare;\n"
    "  __nv_atomic_compare_exchange_n(address.data(), &expected, value, false,\n"
    "                                 __NV_ATOMIC_ACQUIRE, __NV_ATOMIC_ACQUIRE,\n"
    "                                 __NV_THREAD_SCOPE_DEVICE);\n"
    "  return expected;\n"
    "}\n"
    "#endif\n"
    "\n"
    "#if __CUDA_ARCH__ < 600\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, double>\n"
    "template <>\n"
    "__device__ __forceinline__ double\n"
    "CudaSpecialization::atomic_add<breeze::utils::Slice<\n"
    "    breeze::utils::GLOBAL, breeze::utils::BLOCKED, double>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, double>\n"
    "        address,\n"
    "    double value) {\n"
    "  static_assert(sizeof(double) == sizeof(unsigned long long),\n"
    "                \"unexpected type sizes\");\n"
    "  unsigned long long old =\n"
    "      *reinterpret_cast<unsigned long long *>(address.data());\n"
    "  unsigned long long assumed;\n"
    "  do {\n"
    "    assumed = old;\n"
    "    old = atomicCAS(\n"
    "        reinterpret_cast<unsigned long long *>(address.data()), assumed,\n"
    "        __double_as_longlong(value + __longlong_as_double(assumed)));\n"
    "  } while (assumed != old);\n"
    "\n"
    "  return __longlong_as_double(old);\n"
    "}\n"
    "#endif\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, long long>\n"
    "template <>\n"
    "__device__ __forceinline__ long long\n"
    "CudaSpecialization::atomic_add<breeze::utils::Slice<\n"
    "    breeze::utils::GLOBAL, breeze::utils::BLOCKED, long long>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         long long>\n"
    "        address,\n"
    "    long long value) {\n"
    "  unsigned long long result =\n"
    "      atomicAdd(reinterpret_cast<unsigned long long *>(address.data()),\n"
    "                *reinterpret_cast<unsigned long long *>(&value));\n"
    "  return *reinterpret_cast<long long *>(&result);\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, float>\n"
    "template <>\n"
    "__device__ __forceinline__ void CudaSpecialization::atomic_min<\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, float>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, float>\n"
    "        address,\n"
    "    float value) {\n"
    "  static_assert(sizeof(float) == sizeof(unsigned), \"unexpected type sizes\");\n"
    "  float current = atomic_load(address);\n"
    "  while (current > value) {\n"
    "    unsigned old = atomicCAS(reinterpret_cast<unsigned *>(address.data()),\n"
    "                             *reinterpret_cast<unsigned *>(&current),\n"
    "                             *reinterpret_cast<unsigned *>(&value));\n"
    "    current = *reinterpret_cast<float *>(&old);\n"
    "    if (current == value) {\n"
    "      break;\n"
    "    }\n"
    "  }\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, float>\n"
    "template <>\n"
    "__device__ __forceinline__ void CudaSpecialization::atomic_max<\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, float>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, float>\n"
    "        address,\n"
    "    float value) {\n"
    "  static_assert(sizeof(float) == sizeof(unsigned), \"unexpected type sizes\");\n"
    "  float current = atomic_load(address);\n"
    "  while (current < value) {\n"
    "    unsigned old = atomicCAS(reinterpret_cast<unsigned *>(address.data()),\n"
    "                             *reinterpret_cast<unsigned *>(&current),\n"
    "                             *reinterpret_cast<unsigned *>(&value));\n"
    "    current = *reinterpret_cast<float *>(&old);\n"
    "    if (current == value) {\n"
    "      break;\n"
    "    }\n"
    "  }\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, double>\n"
    "template <>\n"
    "__device__ __forceinline__ void\n"
    "CudaSpecialization::atomic_min<breeze::utils::Slice<\n"
    "    breeze::utils::GLOBAL, breeze::utils::BLOCKED, double>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, double>\n"
    "        address,\n"
    "    double value) {\n"
    "  static_assert(sizeof(double) == sizeof(unsigned long long),\n"
    "                \"unexpected type sizes\");\n"
    "  double current = atomic_load(address);\n"
    "  while (current > value) {\n"
    "    unsigned long long old =\n"
    "        atomicCAS(reinterpret_cast<unsigned long long *>(address.data()),\n"
    "                  *reinterpret_cast<unsigned long long *>(&current),\n"
    "                  *reinterpret_cast<unsigned long long *>(&value));\n"
    "    current = *reinterpret_cast<double *>(&old);\n"
    "    if (current == value) {\n"
    "      break;\n"
    "    }\n"
    "  }\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, double>\n"
    "template <>\n"
    "__device__ __forceinline__ void\n"
    "CudaSpecialization::atomic_max<breeze::utils::Slice<\n"
    "    breeze::utils::GLOBAL, breeze::utils::BLOCKED, double>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, double>\n"
    "        address,\n"
    "    double value) {\n"
    "  static_assert(sizeof(double) == sizeof(unsigned long long),\n"
    "                \"unexpected type sizes\");\n"
    "  double current = atomic_load(address);\n"
    "  while (current < value) {\n"
    "    unsigned long long old =\n"
    "        atomicCAS(reinterpret_cast<unsigned long long *>(address.data()),\n"
    "                  *reinterpret_cast<unsigned long long *>(&current),\n"
    "                  *reinterpret_cast<unsigned long long *>(&value));\n"
    "    current = *reinterpret_cast<double *>(&old);\n"
    "    if (current == value) {\n"
    "      break;\n"
    "    }\n"
    "  }\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, long long>\n"
    "template <>\n"
    "__device__ __forceinline__ long long CudaSpecialization::atomic_cas<\n"
    "    breeze::utils::RELAXED,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         long long>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         long long>\n"
    "        address,\n"
    "    long long compare, long long value) {\n"
    "  unsigned long long old =\n"
    "      atomicCAS(reinterpret_cast<unsigned long long *>(address.data()),\n"
    "                *reinterpret_cast<unsigned long long *>(&compare),\n"
    "                *reinterpret_cast<unsigned long long *>(&value));\n"
    "  return *reinterpret_cast<long long *>(&old);\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<SHARED, BLOCKED, long long>\n"
    "template <>\n"
    "__device__ __forceinline__ long long CudaSpecialization::atomic_cas<\n"
    "    breeze::utils::RELAXED,\n"
    "    breeze::utils::Slice<breeze::utils::SHARED, breeze::utils::BLOCKED,\n"
    "                         long long>>(\n"
    "    breeze::utils::Slice<breeze::utils::SHARED, breeze::utils::BLOCKED,\n"
    "                         long long>\n"
    "        address,\n"
    "    long long compare, long long value) {\n"
    "  using pointer_type =\n"
    "      typename breeze::utils::Slice<breeze::utils::SHARED,\n"
    "                                    breeze::utils::BLOCKED,\n"
    "                                    unsigned long long>::pointer_type;\n"
    "  unsigned long long old =\n"
    "      atomicCAS(reinterpret_cast<pointer_type>(address.data()),\n"
    "                *reinterpret_cast<unsigned long long *>(&compare),\n"
    "                *reinterpret_cast<unsigned long long *>(&value));\n"
    "  return *reinterpret_cast<long long *>(&old);\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, float>\n"
    "template <>\n"
    "__device__ __forceinline__ float CudaSpecialization::atomic_cas<\n"
    "    breeze::utils::RELAXED,\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, float>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED, float>\n"
    "        address,\n"
    "    float compare, float value) {\n"
    "  static_assert(sizeof(float) == sizeof(unsigned), \"unexpected type sizes\");\n"
    "  unsigned old = atomicCAS(reinterpret_cast<unsigned *>(address.data()),\n"
    "                           *reinterpret_cast<unsigned *>(&compare),\n"
    "                           *reinterpret_cast<unsigned *>(&value));\n"
    "  return *reinterpret_cast<float *>(&old);\n"
    "}\n"
    "\n"
    "#if __CUDA_ARCH__ >= 800\n"
    "// specialization for T=int\n"
    "template <>\n"
    "__device__ __forceinline__ int CudaSpecialization::reduce_add<32, int>(\n"
    "    int value) {\n"
    "  return __reduce_add_sync(0xffffffff, value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned\n"
    "CudaSpecialization::reduce_add<32, unsigned>(unsigned value) {\n"
    "  return __reduce_add_sync(0xffffffff, value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned long long\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned long long\n"
    "CudaSpecialization::reduce_add<32, unsigned long long>(\n"
    "    unsigned long long value) {\n"
    "  unsigned low =\n"
    "      __reduce_add_sync(0xffffffff, static_cast<unsigned>(value & 0xffffff));\n"
    "  unsigned mid = __reduce_add_sync(\n"
    "      0xffffffff, static_cast<unsigned>((value >> 24) & 0xffffff));\n"
    "  unsigned high =\n"
    "      __reduce_add_sync(0xffffffff, static_cast<unsigned>(value >> 48));\n"
    "  return low + (static_cast<unsigned long long>(mid) << 24) +\n"
    "         (static_cast<unsigned long long>(high) << 48);\n"
    "}\n"
    "\n"
    "// specialization for T=long long\n"
    "template <>\n"
    "__device__ __forceinline__ long long\n"
    "CudaSpecialization::reduce_add<32, long long>(long long value) {\n"
    "  return reduce_add<32, unsigned long long>(value);\n"
    "}\n"
    "\n"
    "// specialization for T=int\n"
    "template <>\n"
    "__device__ __forceinline__ int CudaSpecialization::reduce_min<32, int>(\n"
    "    int value) {\n"
    "  return __reduce_min_sync(0xffffffff, value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned\n"
    "CudaSpecialization::reduce_min<32, unsigned>(unsigned value) {\n"
    "  return __reduce_min_sync(0xffffffff, value);\n"
    "}\n"
    "\n"
    "// specialization for T=long long\n"
    "template <>\n"
    "__device__ __forceinline__ long long\n"
    "CudaSpecialization::reduce_min<32, long long>(long long value) {\n"
    "  int high = __reduce_min_sync(0xffffffff, static_cast<int>(value >> 32));\n"
    "  bool match_high = high == static_cast<int>(value >> 32);\n"
    "  // force threads that lost the first reduction to lose the second reduction\n"
    "  unsigned sel_low =\n"
    "      match_high ? static_cast<unsigned>(value & 0xffffffff) : ~0u;\n"
    "  unsigned low = __reduce_min_sync(0xffffffff, sel_low);\n"
    "  return low + (static_cast<long long>(high) << 32);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned long long\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned long long\n"
    "CudaSpecialization::reduce_min<32, unsigned long long>(\n"
    "    unsigned long long value) {\n"
    "  unsigned high =\n"
    "      __reduce_min_sync(0xffffffff, static_cast<unsigned>(value >> 32));\n"
    "  bool match_high = high == static_cast<unsigned>(value >> 32);\n"
    "  // force threads that lost the first reduction to lose the second reduction\n"
    "  unsigned sel_low =\n"
    "      match_high ? static_cast<unsigned>(value & 0xffffffff) : ~0u;\n"
    "  unsigned low = __reduce_min_sync(0xffffffff, sel_low);\n"
    "  return low + (static_cast<unsigned long long>(high) << 32);\n"
    "}\n"
    "\n"
    "// specialization for T=int\n"
    "template <>\n"
    "__device__ __forceinline__ int CudaSpecialization::reduce_max<32, int>(\n"
    "    int value) {\n"
    "  return __reduce_max_sync(0xffffffff, value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned\n"
    "CudaSpecialization::reduce_max<32, unsigned>(unsigned value) {\n"
    "  return __reduce_max_sync(0xffffffff, value);\n"
    "}\n"
    "\n"
    "// specialization for T=long long\n"
    "template <>\n"
    "__device__ __forceinline__ long long\n"
    "CudaSpecialization::reduce_max<32, long long>(long long value) {\n"
    "  int high = __reduce_max_sync(0xffffffff, static_cast<int>(value >> 32));\n"
    "  bool match_high = high == static_cast<int>(value >> 32);\n"
    "  // force threads that lost the first reduction to lose the second reduction\n"
    "  unsigned sel_low =\n"
    "      match_high ? static_cast<unsigned>(value & 0xffffffff) : 0u;\n"
    "  unsigned low = __reduce_max_sync(0xffffffff, sel_low);\n"
    "  return low + (static_cast<long long>(high) << 32);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned long long\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned long long\n"
    "CudaSpecialization::reduce_max<32, unsigned long long>(\n"
    "    unsigned long long value) {\n"
    "  unsigned high =\n"
    "      __reduce_max_sync(0xffffffff, static_cast<unsigned>(value >> 32));\n"
    "  bool match_high = high == static_cast<unsigned>(value >> 32);\n"
    "  // force threads that lost the first reduction to lose the second reduction\n"
    "  unsigned sel_low =\n"
    "      match_high ? static_cast<unsigned>(value & 0xffffffff) : 0u;\n"
    "  unsigned low = __reduce_max_sync(0xffffffff, sel_low);\n"
    "  return low + (static_cast<unsigned long long>(high) << 32);\n"
    "}\n"
    "#endif  // __CUDA_ARCH__ >= 800\n"
    "\n"
    "// specialization for T=unsigned\n"
    "template <>\n"
    "__device__ __forceinline__ int\n"
    "CudaSpecialization::count_leading_zeros<unsigned>(unsigned value) {\n"
    "  return __clz(value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned long long\n"
    "template <>\n"
    "__device__ __forceinline__ int\n"
    "CudaSpecialization::count_leading_zeros<unsigned long long>(\n"
    "    unsigned long long value) {\n"
    "  return __clzll(value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned\n"
    "template <>\n"
    "__device__ __forceinline__ int CudaSpecialization::population_count<unsigned>(\n"
    "    unsigned value) {\n"
    "  return __popc(value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned long long\n"
    "template <>\n"
    "__device__ __forceinline__ int\n"
    "CudaSpecialization::population_count<unsigned long long>(\n"
    "    unsigned long long value) {\n"
    "  return __popcll(value);\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned long long\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned long long CudaSpecialization::extract_bits(\n"
    "    unsigned long long value, int start_bit, int num_bits) {\n"
    "  unsigned long long mask = (1llu << num_bits) - 1;\n"
    "  return (value >> start_bit) & mask;\n"
    "}\n"
    "\n"
    "#define Q(x) #x\n"
    "#define QUOTE(x) Q(x)\n"
    "#include QUOTE(CUDA_PLATFORM_SPECIALIZATION_HEADER)\n"
    "#undef QUOTE\n"
    "#undef Q\n";
// Hands the embedded breeze/platforms/cuda.cuh text to registerHeader().
// Because the call lives in a namespace-scope initializer, it executes during
// this translation unit's dynamic initialization; the returned bool is stored
// only so the call has a home (its meaning is defined by registerHeader).
bool breeze_platforms_cuda_cuh_reg{registerHeader(breeze_platforms_cuda_cuh)};
// Embedded copy of breeze/platforms/specialization/cuda-ptx.cuh.
//
// Layout of the literal: the first line is the include path
// ("breeze/platforms/specialization/cuda-ptx.cuh"), and everything after the
// first '\n' is the header text. Presumably registerHeader() splits on that
// first newline to map include path -> contents (TODO confirm in
// registerHeader).
//
// What the embedded header defines:
//   - CudaSpecialization::lane_idx / warp_idx: threadIdx.x % WARP_THREADS and
//     threadIdx.x / WARP_THREADS.
//   - CudaSpecialization::lower_rank_lanemask / higher_rank_lanemask: read the
//     PTX %lanemask_lt / %lanemask_gt special registers. The WARP_THREADS
//     template argument is not used by these two.
//   - CudaSpecialization::extract_bits<unsigned>: implemented with PTX
//     bfe.u32.
//   - CudaSpecialization::prefetch for Slice<GLOBAL, BLOCKED, unsigned long
//     long>: implemented with PTX prefetch.global.L1.
//
// Preconditions: the text defines CudaSpecialization members out of class and
// names breeze::utils::Slice, so it only compiles after both are declared.
// Presumably it is pulled in by the
// `#include QUOTE(CUDA_PLATFORM_SPECIALIZATION_HEADER)` at the end of the
// embedded breeze/platforms/cuda.cuh when that macro names this file.
//
// NOTE(review): this literal looks like a generated mirror of the upstream
// breeze header. Every byte is runtime data, so changes presumably belong in
// the source header followed by regeneration; confirm with the generator.
const char* breeze_platforms_specialization_cuda_ptx_cuh =
    "breeze/platforms/specialization/cuda-ptx.cuh\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "/*\n"
    " * Copyright (c) 2024 by Rivos Inc.\n"
    " * Licensed under the Apache License, Version 2.0, see LICENSE for details.\n"
    " * SPDX-License-Identifier: Apache-2.0\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "template <int WARP_THREADS>\n"
    "__device__ __forceinline__ int CudaSpecialization::lane_idx() {\n"
    "  return threadIdx.x % WARP_THREADS;\n"
    "}\n"
    "\n"
    "template <int WARP_THREADS>\n"
    "__device__ __forceinline__ int CudaSpecialization::warp_idx() {\n"
    "  return threadIdx.x / WARP_THREADS;\n"
    "}\n"
    "\n"
    "__device__ __forceinline__ unsigned LANEMASK_LT() {\n"
    "  unsigned result;\n"
    "  asm(\"mov.u32 %0, %%lanemask_lt;\" : \"=r\"(result));\n"
    "  return result;\n"
    "}\n"
    "\n"
    "template <int WARP_THREADS>\n"
    "__device__ __forceinline__ unsigned CudaSpecialization::lower_rank_lanemask() {\n"
    "  return LANEMASK_LT();\n"
    "}\n"
    "\n"
    "__device__ __forceinline__ unsigned LANEMASK_GT() {\n"
    "  unsigned result;\n"
    "  asm(\"mov.u32 %0, %%lanemask_gt;\" : \"=r\"(result));\n"
    "  return result;\n"
    "}\n"
    "\n"
    "template <int WARP_THREADS>\n"
    "__device__ __forceinline__ unsigned CudaSpecialization::higher_rank_lanemask() {\n"
    "  return LANEMASK_GT();\n"
    "}\n"
    "\n"
    "__device__ __forceinline__ unsigned BFE(unsigned value, int start_bit,\n"
    "                                        int num_bits) {\n"
    "  asm(\"bfe.u32 %0, %1, %2, %3;\"\n"
    "      : \"=r\"(value)\n"
    "      : \"r\"(value), \"r\"(start_bit), \"r\"(num_bits));\n"
    "  return value;\n"
    "}\n"
    "\n"
    "// specialization for T=unsigned\n"
    "template <>\n"
    "__device__ __forceinline__ unsigned CudaSpecialization::extract_bits(\n"
    "    unsigned value, int start_bit, int num_bits) {\n"
    "  return BFE(value, start_bit, num_bits);\n"
    "}\n"
    "\n"
    // NOTE(review): the embedded PREFETCH below wraps prefetch.global.L1 in a
    // plain `asm` with no output operands. NVIDIA's inline-PTX guide recommends
    // `asm volatile` for statements that must not be deleted or moved. Confirm
    // against the upstream breeze header before changing the text here.
    "__device__ __forceinline__ void PREFETCH(void *ptr) {\n"
    "  asm(\"prefetch.global.L1 [%0];\" ::\"l\"(ptr));\n"
    "}\n"
    "\n"
    "// specialization for T=Slice<GLOBAL, BLOCKED, unsigned long long>\n"
    "template <>\n"
    "__device__ __forceinline__ void\n"
    "CudaSpecialization::prefetch<breeze::utils::Slice<\n"
    "    breeze::utils::GLOBAL, breeze::utils::BLOCKED, unsigned long long>>(\n"
    "    breeze::utils::Slice<breeze::utils::GLOBAL, breeze::utils::BLOCKED,\n"
    "                         unsigned long long>\n"
    "        address) {\n"
    "  PREFETCH(address.data());\n"
    "}\n";
// Hands the embedded breeze/platforms/specialization/cuda-ptx.cuh text to
// registerHeader() during this translation unit's dynamic initialization; the
// bool result is kept only so the namespace-scope call has a home.
bool breeze_platforms_specialization_cuda_ptx_cuh_reg{
    registerHeader(breeze_platforms_specialization_cuda_ptx_cuh)};
// Embedded copy of breeze/platforms/platform.h.
//
// Layout of the literal: the first line is the include path
// ("breeze/platforms/platform.h"), and everything after the first '\n' is the
// header text. Presumably registerHeader() uses that first line as the
// header's name (TODO confirm in registerHeader).
//
// What the embedded header does:
//   - Includes a platform system header:
//       PLATFORM_CUDA  -> <cuda_runtime_api.h>
//       PLATFORM_HIP   -> <hip/hip_runtime.h>
//       PLATFORM_METAL -> <metal_stdlib>
//     No system header is included for SYCL/OpenCL/OpenMP.
//   - Defines the ATTR function-attribute macro:
//       PLATFORM_CUDA or PLATFORM_HIP                  -> `__device__ __forceinline__`
//       PLATFORM_SYCL/METAL/OPENCL/OPENMP              -> `inline`
//   - Stops compilation with `#error platform must be set` if none of those
//     PLATFORM_* macros is defined. Whoever compiles text that includes it
//     (e.g. breeze/utils/types.h) therefore has to define one.
//
// NOTE(review): this literal looks like a generated mirror of the upstream
// breeze header. Every byte is runtime data, so changes presumably belong in
// the source header followed by regeneration; confirm with the generator.
const char* breeze_platforms_platform_h =
    "breeze/platforms/platform.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "/*\n"
    " * Copyright (c) 2024 by Rivos Inc.\n"
    " * Licensed under the Apache License, Version 2.0, see LICENSE for details.\n"
    " * SPDX-License-Identifier: Apache-2.0\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "// some platforms require a system header to be included\n"
    "#if defined(PLATFORM_CUDA)\n"
    "#include <cuda_runtime_api.h>\n"
    "#elif defined(PLATFORM_HIP)\n"
    "#include <hip/hip_runtime.h>\n"
    "#elif defined(PLATFORM_METAL)\n"
    "#include <metal_stdlib>\n"
    "#endif\n"
    "\n"
    "// function attributes used for device code\n"
    "#if defined(PLATFORM_CUDA) || defined(PLATFORM_HIP)\n"
    "#define ATTR __device__ __forceinline__\n"
    "#elif defined(PLATFORM_SYCL) || defined(PLATFORM_METAL) || \\\n"
    "    defined(PLATFORM_OPENCL) || defined(PLATFORM_OPENMP)\n"
    "#define ATTR inline\n"
    "#else\n"
    "#error platform must be set\n"
    "#endif\n";
// Hands the embedded breeze/platforms/platform.h text to registerHeader()
// during this translation unit's dynamic initialization; the bool result is
// kept only so the namespace-scope call has a home.
bool breeze_platforms_platform_h_reg{
    registerHeader(breeze_platforms_platform_h)};
const char* breeze_utils_types_h =
    "breeze/utils/types.h\n"
    "/*\n"
    " * Copyright (c) Facebook, Inc. and its affiliates.\n"
    " *\n"
    " * Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    " * you may not use this file except in compliance with the License.\n"
    " * You may obtain a copy of the License at\n"
    " *\n"
    " *     http://www.apache.org/licenses/LICENSE-2.0\n"
    " *\n"
    " * Unless required by applicable law or agreed to in writing, software\n"
    " * distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    " * See the License for the specific language governing permissions and\n"
    " * limitations under the License.\n"
    " */\n"
    "\n"
    "/*\n"
    " * Copyright (c) 2024 by Rivos Inc.\n"
    " * Licensed under the Apache License, Version 2.0, see LICENSE for details.\n"
    " * SPDX-License-Identifier: Apache-2.0\n"
    " */\n"
    "\n"
    "#pragma once\n"
    "\n"
    "#include \"breeze/platforms/platform.h\"\n"
    "\n"
    "#if !defined(__CUDACC_RTC__) && defined(__EXCEPTIONS)\n"
    "#include <exception>\n"
    "#include <string>\n"
    "#endif\n"
    "\n"
    "namespace breeze {\n"
    "namespace utils {\n"
    "\n"
    "using size_type = int;\n"
    "\n"
    "template <int N, int M>\n"
    "struct Min {\n"
    "  enum { VALUE = ((N < M) ? N : M) };\n"
    "};\n"
    "\n"
    "template <int N, int M>\n"
    "struct Max {\n"
    "  enum { VALUE = ((N > M) ? N : M) };\n"
    "};\n"
    "\n"
    "template <int N>\n"
    "struct PowerOfTwo {\n"
    "  enum { VALUE = ((N & (N - 1)) == 0) };\n"
    "};\n"
    "\n"
    "template <int N, int D>\n"
    "struct DivideAndRoundUp {\n"
    "  enum { VALUE = N / D + (N % D != 0 ? 1 : 0) };\n"
    "};\n"
    "\n"
    "template <int N, int D>\n"
    "struct RoundUp {\n"
    "  enum {\n"
    "    DIVIDE_AND_ROUND_UP_VALUE = DivideAndRoundUp<N, D>::VALUE,\n"
    "    VALUE = DIVIDE_AND_ROUND_UP_VALUE * D,\n"
    "  };\n"
    "};\n"
    "\n"
    "template <typename T>\n"
    "struct NumericLimits {\n"
    "  static ATTR T min();\n"
    "  static ATTR T max();\n"
    "};\n"
    "\n"
    "// specialization for T=int\n"
    "template <>\n"
    "struct NumericLimits<int> {\n"
    "  static ATTR int min() { return -2147483648; }\n"
    "  static ATTR int max() { return 2147483647; }\n"
    "};\n"
    "\n"
    "// specialization for T=unsigned\n"
    "template <>\n"
    "struct NumericLimits<unsigned> {\n"
    "  static ATTR unsigned min() { return 0u; }\n"
    "  static ATTR unsigned max() { return 4294967295u; }\n"
    "};\n"
    "\n"
    "// specialization for T=float\n"
    "template <>\n"
    "struct NumericLimits<float> {\n"
    "  static ATTR float min() { return 1.17549435e-38f; }\n"
    "  static ATTR float max() { return 3.40282347e+38f; }\n"
    "};\n"
    "\n"
    "#if !defined(PLATFORM_METAL)\n"
    "\n"
    "// specialization for T=long long\n"
    "template <>\n"
    "struct NumericLimits<long long> {\n"
    "  static ATTR long long min() { return -9223372036854775807ll; }\n"
    "  static ATTR long long max() { return 9223372036854775807ll; }\n"
    "};\n"
    "\n"
    "// specialization for T=unsigned long long\n"
    "template <>\n"
    "struct NumericLimits<unsigned long long> {\n"
    "  static ATTR unsigned long long min() { return 0; }\n"
    "  static ATTR unsigned long long max() { return 18446744073709551615llu; }\n"
    "};\n"
    "\n"
    "// specialization for T=double\n"
    "template <>\n"
    "struct NumericLimits<double> {\n"
    "  static ATTR double min() { return 2.22507385850720138309e-308L; }\n"
    "  static ATTR double max() { return 1.79769313486231570815e+308L; }\n"
    "};\n"
    "\n"
    "#endif  // !defined(PLATFORM_METAL)\n"
    "\n"
    "template <typename T>\n"
    "struct Msb {\n"
    "  enum { VALUE = sizeof(T) * /*CHAR_BIT=*/8 - 1 };\n"
    "};\n"
    "\n"
    "#ifndef __CUDACC_RTC__\n"
    "\n"
    "template <typename T>\n"
    "T NextPowerOfTwo(T value);\n"
    "\n"
    "// specialization for T=int\n"
    "template <>\n"
    "inline int NextPowerOfTwo(int value) {\n"
    "  return value == 1 ? 1 : 1 << (32 - __builtin_clz(value - 1));\n"
    "}\n"
    "\n"
    "#endif\n"
    "\n"
    "template <typename T, typename U>\n"
    "struct IsSame {\n"
    "  enum { VALUE = 0 };\n"
    "};\n"
    "\n"
    "template <typename T>\n"
    "struct IsSame<T, T> {\n"
    "  enum { VALUE = 1 };\n"
    "};\n"
    "\n"
    "template <typename T, typename U>\n"
    "struct IsDifferent {\n"
    "  enum { VALUE = !IsSame<T, U>::VALUE };\n"
    "};\n"
    "\n"
    "template <typename T>\n"
    "struct RemoveConstT {\n"
    "  using Type = T;\n"
    "};\n"
    "\n"
    "template <typename T>\n"
    "struct RemoveConstT<const T> {\n"
    "  using Type = T;\n"
    "};\n"
    "\n"
    "template <typename T>\n"
    "using RemoveConst = typename RemoveConstT<T>::Type;\n"
    "\n"
    "#if !defined(__CUDACC_RTC__) && defined(__EXCEPTIONS)\n"
    "\n"
    "// custom exception used for device allocation failures\n"
    "class BadDeviceAlloc : public std::exception {\n"
    " public:\n"
    "  BadDeviceAlloc(size_t size, size_t free, size_t total)\n"
    "      : message_(\"BadDeviceAlloc(size=\" + std::to_string(size) +\n"
    "                 \",free=\" + std::to_string(free) +\n"
    "                 \",total=\" + std::to_string(total) + \")\") {}\n"
    "\n"
    "  virtual const char *what() const noexcept { return message_.c_str(); }\n"
    "\n"
    " private:\n"
    "  std::string message_;\n"
    "};\n"
    "\n"
    "#endif  // !__CUDACC_RTC__ && __EXCEPTIONS\n"
    "\n"
    "enum AddressSpace {\n"
    "  THREAD,\n"
    "  SHARED,\n"
    "  GLOBAL,\n"
    "};\n"
    "\n"
    "enum DataArrangement {\n"
    "  BLOCKED,\n"
    "  STRIPED,\n"
    "  WARP_STRIPED,\n"
    "};\n"
    "\n"
    "enum MemoryOrder {\n"
    "  RELAXED,\n"
    "  ACQUIRE,\n"
    "  RELEASE,\n"
    "};\n"
    "\n"
    "class NullType {};\n"
    "\n"
    "class EmptySlice {\n"
    "  using data_type = NullType;\n"
    "};\n"
    "\n"
    "ATTR EmptySlice constexpr make_empty_slice() { return EmptySlice{}; }\n"
    "\n"
    "// metal platform requires specialization as it has native address space\n"
    "// qualifiers\n"
    "#if defined(PLATFORM_METAL)\n"
    "\n"
    "template <AddressSpace A, DataArrangement B, typename T>\n"
    "class Slice {};\n"
    "\n"
    "// partial specialization for ADDRESS_SPACE=THREAD\n"
    "template <DataArrangement A, typename T>\n"
    "class Slice<THREAD, A, T> {\n"
    " public:\n"
    "  static constant AddressSpace ADDRESS_SPACE = THREAD;\n"
    "  static constant DataArrangement ARRANGEMENT = A;\n"
    "  using data_type = T;\n"
    "\n"
    "  ATTR explicit Slice(thread T *data) : data_(data) {}\n"
    "  ATTR T const thread &operator[](int index) const { return data_[index]; }\n"
    "  ATTR T thread &operator[](int index) { return data_[index]; }\n"
    "  ATTR T const thread &operator*() const { return *data_; }\n"
    "  ATTR T thread &operator*() { return *data_; }\n"
    "  ATTR T const thread *operator->() const { return data_; }\n"
    "  ATTR T thread *operator->() { return data_; }\n"
    "  ATTR Slice<THREAD, A, T> subslice(int offset) {\n"
    "    return Slice<THREAD, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR const Slice<THREAD, A, T> subslice(int offset) const {\n"
    "    return Slice<THREAD, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR operator bool() const { return data_ != nullptr; }\n"
    "  ATTR T thread *data() { return data_; }\n"
    "  ATTR T const thread *data() const { return data_; }\n"
    "\n"
    " private:\n"
    "  thread T *data_;\n"
    "};\n"
    "\n"
    "template <AddressSpace A = THREAD, DataArrangement B = STRIPED, typename T>\n"
    "ATTR Slice<THREAD, B, T> constexpr make_slice(thread T *data) {\n"
    "  return Slice<THREAD, B, T>(data);\n"
    "}\n"
    "\n"
    "// partial specialization for ADDRESS_SPACE=SHARED\n"
    "template <DataArrangement A, typename T>\n"
    "class Slice<SHARED, A, T> {\n"
    " public:\n"
    "  static constant AddressSpace ADDRESS_SPACE = SHARED;\n"
    "  static constant DataArrangement ARRANGEMENT = A;\n"
    "  using data_type = T;\n"
    "\n"
    "  ATTR explicit Slice(threadgroup T *data) : data_(data) {}\n"
    "  ATTR T const threadgroup &operator[](int index) const { return data_[index]; }\n"
    "  ATTR T threadgroup &operator[](int index) { return data_[index]; }\n"
    "  ATTR T const threadgroup &operator*() const { return *data_; }\n"
    "  ATTR T threadgroup &operator*() { return *data_; }\n"
    "  ATTR T const threadgroup *operator->() const { return data_; }\n"
    "  ATTR T threadgroup *operator->() { return data_; }\n"
    "  ATTR Slice<SHARED, A, T> subslice(int offset) {\n"
    "    return Slice<SHARED, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR const Slice<SHARED, A, T> subslice(int offset) const {\n"
    "    return Slice<SHARED, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR operator bool() const { return data_ != nullptr; }\n"
    "  ATTR T threadgroup *data() { return data_; }\n"
    "  ATTR T const threadgroup *data() const { return data_; }\n"
    "\n"
    "  template <AddressSpace OTHER>\n"
    "  ATTR Slice<OTHER, A, T> reinterpret();\n"
    "  // This is a no-op, but is needed for consistency with other platforms\n"
    "  template <>\n"
    "  ATTR Slice<SHARED, A, T> reinterpret<SHARED>() {\n"
    "    return *this;\n"
    "  }\n"
    "\n"
    " private:\n"
    "  threadgroup T *data_;\n"
    "};\n"
    "\n"
    "template <AddressSpace A = SHARED, DataArrangement B = BLOCKED, typename T>\n"
    "ATTR Slice<SHARED, B, T> constexpr make_slice(threadgroup T *data) {\n"
    "  return Slice<SHARED, B, T>(data);\n"
    "}\n"
    "\n"
    "// partial specialization for ADDRESS_SPACE=GLOBAL\n"
    "template <DataArrangement A, typename T>\n"
    "class Slice<GLOBAL, A, T> {\n"
    " public:\n"
    "  static constant AddressSpace ADDRESS_SPACE = GLOBAL;\n"
    "  static constant DataArrangement ARRANGEMENT = A;\n"
    "  using data_type = T;\n"
    "\n"
    "  ATTR explicit Slice(device T *data) : data_(data) {}\n"
    "  ATTR T const device &operator[](int index) const { return data_[index]; }\n"
    "  ATTR T device &operator[](int index) { return data_[index]; }\n"
    "  ATTR T const device &operator*() const { return *data_; }\n"
    "  ATTR T device &operator*() { return *data_; }\n"
    "  ATTR Slice<GLOBAL, A, T> subslice(int offset) {\n"
    "    return Slice<GLOBAL, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR const Slice<GLOBAL, A, T> subslice(int offset) const {\n"
    "    return Slice<GLOBAL, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR operator bool() const { return data_ != nullptr; }\n"
    "  ATTR T device *data() { return data_; }\n"
    "  ATTR T const device *data() const { return data_; }\n"
    "\n"
    " private:\n"
    "  device T *data_;\n"
    "};\n"
    "\n"
    "template <AddressSpace A = GLOBAL, DataArrangement B = BLOCKED, typename T>\n"
    "ATTR Slice<GLOBAL, B, T> constexpr make_slice(device T *data) {\n"
    "  return Slice<GLOBAL, B, T>(data);\n"
    "}\n"
    "\n"
    "#elif defined(PLATFORM_OPENCL)\n"
    "\n"
    "template <AddressSpace A, DataArrangement B, typename T>\n"
    "class Slice {};\n"
    "\n"
    "// partial specialization for ADDRESS_SPACE=THREAD\n"
    "template <DataArrangement A, typename T>\n"
    "class Slice<THREAD, A, T> {\n"
    " public:\n"
    "  static constexpr AddressSpace ADDRESS_SPACE = THREAD;\n"
    "  static constexpr DataArrangement ARRANGEMENT = A;\n"
    "  using data_type = T;\n"
    "\n"
    "  ATTR explicit Slice(__private T *data) : data_(data) {}\n"
    "  ATTR T const __private &operator[](int index) const { return data_[index]; }\n"
    "  ATTR T __private &operator[](int index) { return data_[index]; }\n"
    "  ATTR T const __private &operator*() const { return *data_; }\n"
    "  ATTR T __private &operator*() { return *data_; }\n"
    "  ATTR T const __private *operator->() const { return data_; }\n"
    "  ATTR T __private *operator->() { return data_; }\n"
    "  ATTR Slice<THREAD, A, T> subslice(int offset) {\n"
    "    return Slice<THREAD, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR const Slice<THREAD, A, T> subslice(int offset) const {\n"
    "    return Slice<THREAD, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR operator bool() const { return data_ != nullptr; }\n"
    "  ATTR T __private *data() { return data_; }\n"
    "  ATTR T const __private *data() const { return data_; }\n"
    "\n"
    "  template <AddressSpace OTHER>\n"
    "  ATTR Slice<OTHER, A, T> reinterpret();\n"
    "  // This is a no-op, but is needed for consistency with other platforms\n"
    "  template <>\n"
    "  ATTR Slice<THREAD, A, T> reinterpret<THREAD>() {\n"
    "    return *this;\n"
    "  }\n"
    "\n"
    " private:\n"
    "  __private T *data_;\n"
    "};\n"
    "\n"
    "template <AddressSpace A = THREAD, DataArrangement B = STRIPED, typename T>\n"
    "ATTR Slice<THREAD, B, T> constexpr make_slice(private T *data) {\n"
    "  return Slice<THREAD, B, T>(data);\n"
    "}\n"
    "\n"
    "// partial specialization for ADDRESS_SPACE=SHARED\n"
    "template <DataArrangement A, typename T>\n"
    "class Slice<SHARED, A, T> {\n"
    " public:\n"
    "  static constexpr AddressSpace ADDRESS_SPACE = SHARED;\n"
    "  static constexpr DataArrangement ARRANGEMENT = A;\n"
    "  using data_type = T;\n"
    "\n"
    "  ATTR explicit Slice(local T *data) : data_(data) {}\n"
    "  ATTR T const local &operator[](int index) const { return data_[index]; }\n"
    "  ATTR T local &operator[](int index) { return data_[index]; }\n"
    "  ATTR T const local &operator*() const { return *data_; }\n"
    "  ATTR T local &operator*() { return *data_; }\n"
    "  ATTR T const local *operator->() const { return data_; }\n"
    "  ATTR T local *operator->() { return data_; }\n"
    "  ATTR Slice<SHARED, A, T> subslice(int offset) {\n"
    "    return Slice<SHARED, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR const Slice<SHARED, A, T> subslice(int offset) const {\n"
    "    return Slice<SHARED, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR operator bool() const { return data_ != nullptr; }\n"
    "  ATTR T local *data() { return data_; }\n"
    "  ATTR T const local *data() const { return data_; }\n"
    "\n"
    "  template <AddressSpace OTHER>\n"
    "  ATTR Slice<OTHER, A, T> reinterpret();\n"
    "  // This is a no-op, but is needed for consistency with other platforms\n"
    "  template <>\n"
    "  ATTR Slice<SHARED, A, T> reinterpret<SHARED>() {\n"
    "    return *this;\n"
    "  }\n"
    "\n"
    " private:\n"
    "  local T *data_;\n"
    "};\n"
    "\n"
    "template <AddressSpace A = SHARED, DataArrangement B = BLOCKED, typename T>\n"
    "ATTR Slice<SHARED, B, T> constexpr make_slice(local T *data) {\n"
    "  return Slice<SHARED, B, T>(data);\n"
    "}\n"
    "\n"
    "// partial specialization for ADDRESS_SPACE=GLOBAL\n"
    "template <DataArrangement A, typename T>\n"
    "class Slice<GLOBAL, A, T> {\n"
    " public:\n"
    "  static constexpr AddressSpace ADDRESS_SPACE = GLOBAL;\n"
    "  static constexpr DataArrangement ARRANGEMENT = A;\n"
    "  using data_type = T;\n"
    "\n"
    "  ATTR explicit Slice(global T *data) : data_(data) {}\n"
    "  ATTR T const global &operator[](int index) const { return data_[index]; }\n"
    "  ATTR T global &operator[](int index) { return data_[index]; }\n"
    "  ATTR T const global &operator*() const { return *data_; }\n"
    "  ATTR T global &operator*() { return *data_; }\n"
    "  ATTR Slice<GLOBAL, A, T> subslice(int offset) {\n"
    "    return Slice<GLOBAL, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR const Slice<GLOBAL, A, T> subslice(int offset) const {\n"
    "    return Slice<GLOBAL, A, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR operator bool() const { return data_ != nullptr; }\n"
    "  ATTR T global *data() { return data_; }\n"
    "  ATTR T const global *data() const { return data_; }\n"
    "\n"
    " private:\n"
    "  global T *data_;\n"
    "};\n"
    "\n"
    "template <AddressSpace A = GLOBAL, DataArrangement B = BLOCKED, typename T>\n"
    "ATTR Slice<GLOBAL, B, T> constexpr make_slice(global T *data) {\n"
    "  return Slice<GLOBAL, B, T>(data);\n"
    "}\n"
    "\n"
    "#else  // defined(PLATFORM_OPENCL)\n"
    "\n"
    "template <AddressSpace A, DataArrangement B, typename T>\n"
    "class Slice {\n"
    " public:\n"
    "  static constexpr AddressSpace ADDRESS_SPACE = A;\n"
    "  static constexpr DataArrangement ARRANGEMENT = B;\n"
    "  using data_type = T;\n"
    "  using pointer_type = T *;\n"
    "\n"
    "  ATTR explicit Slice(pointer_type data) : data_(data) {}\n"
    "  ATTR T const &operator[](int index) const { return data_[index]; }\n"
    "  ATTR T &operator[](int index) { return data_[index]; }\n"
    "  ATTR T const &operator*() const { return *data_; }\n"
    "  ATTR T &operator*() { return *data_; }\n"
    "  ATTR T const *operator->() const { return data_; }\n"
    "  ATTR T *operator->() { return data_; }\n"
    "  ATTR Slice<A, B, T> subslice(int offset) {\n"
    "    return Slice<A, B, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR const Slice<A, B, T> subslice(int offset) const {\n"
    "    return Slice<A, B, T>(data_ + offset);\n"
    "  }\n"
    "  ATTR operator bool() const { return data_ != nullptr; }\n"
    "  ATTR T *data() { return data_; }\n"
    "  ATTR T const *data() const { return data_; }\n"
    "\n"
    "  template <AddressSpace OTHER, DataArrangement OTHER_ARRANGEMENT =\n"
    "                                    OTHER == THREAD ? STRIPED : BLOCKED>\n"
    "  ATTR Slice<OTHER, OTHER_ARRANGEMENT, T> reinterpret() {\n"
    "    using other_slice_type = Slice<OTHER, OTHER_ARRANGEMENT, T>;\n"
    "    return other_slice_type((typename other_slice_type::pointer_type)data_);\n"
    "  }\n"
    "\n"
    " private:\n"
    "  pointer_type data_;\n"
    "};\n"
    "\n"
    "// STRIPED arrangement by default for THREAD address space and BLOCKED\n"
    "// arrangement by default for other address spaces\n"
    "template <AddressSpace A = THREAD,\n"
    "          DataArrangement B = A == THREAD ? STRIPED : BLOCKED, typename T>\n"
    "ATTR Slice<A, B, T> constexpr make_slice(T *data) {\n"
    "  return Slice<A, B, T>(data);\n"
    "}\n"
    "\n"
    "#endif  // !defined(PLATFORM_METAL) && !defined(PLATFORM_OPENCL)\n"
    "\n"
    "}  // namespace utils\n"
    "}  // namespace breeze\n";
// Hands the embedded header text above to registerHeader() while this
// translation unit's namespace-scope variables are initialized. The bool
// result is kept only so that the call has an initializer to run from.
// NOTE(review): presumably registerHeader() adds the text to a header table
// used for runtime compilation — confirm against its definition.
bool breeze_utils_types_h_reg{registerHeader(breeze_utils_types_h)};
} // namespace facebook::velox::wave
